MLIR  20.0.0git
SPIRVToLLVM.cpp
Go to the documentation of this file.
1 //===- SPIRVToLLVM.cpp - SPIR-V to LLVM Patterns --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert SPIR-V dialect to LLVM dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
21 #include "mlir/IR/BuiltinOps.h"
22 #include "mlir/IR/PatternMatch.h"
24 #include "llvm/Support/Debug.h"
25 #include "llvm/Support/FormatVariadic.h"
26 
27 #define DEBUG_TYPE "spirv-to-llvm-pattern"
28 
29 using namespace mlir;
30 
31 //===----------------------------------------------------------------------===//
32 // Constants
33 //===----------------------------------------------------------------------===//
34 
/// Address space returned by the address-space mapping helpers in this file
/// when no more specific mapping applies.
constexpr unsigned defaultAddressSpace{0};
36 
37 //===----------------------------------------------------------------------===//
38 // Utility functions
39 //===----------------------------------------------------------------------===//
40 
41 /// Returns true if the given type is a signed integer or vector type.
42 static bool isSignedIntegerOrVector(Type type) {
43  if (type.isSignedInteger())
44  return true;
45  if (auto vecType = dyn_cast<VectorType>(type))
46  return vecType.getElementType().isSignedInteger();
47  return false;
48 }
49 
50 /// Returns true if the given type is an unsigned integer or vector type
51 static bool isUnsignedIntegerOrVector(Type type) {
52  if (type.isUnsignedInteger())
53  return true;
54  if (auto vecType = dyn_cast<VectorType>(type))
55  return vecType.getElementType().isUnsignedInteger();
56  return false;
57 }
58 
59 /// Returns the width of an integer or of the element type of an integer vector,
60 /// if applicable.
61 static std::optional<uint64_t> getIntegerOrVectorElementWidth(Type type) {
62  if (auto intType = dyn_cast<IntegerType>(type))
63  return intType.getWidth();
64  if (auto vecType = dyn_cast<VectorType>(type))
65  if (auto intType = dyn_cast<IntegerType>(vecType.getElementType()))
66  return intType.getWidth();
67  return std::nullopt;
68 }
69 
70 /// Returns the bit width of integer, float or vector of float or integer values
71 static unsigned getBitWidth(Type type) {
72  assert((type.isIntOrFloat() || isa<VectorType>(type)) &&
73  "bitwidth is not supported for this type");
74  if (type.isIntOrFloat())
75  return type.getIntOrFloatBitWidth();
76  auto vecType = dyn_cast<VectorType>(type);
77  auto elementType = vecType.getElementType();
78  assert(elementType.isIntOrFloat() &&
79  "only integers and floats have a bitwidth");
80  return elementType.getIntOrFloatBitWidth();
81 }
82 
83 /// Returns the bit width of LLVMType integer or vector.
84 static unsigned getLLVMTypeBitWidth(Type type) {
85  return cast<IntegerType>((LLVM::isCompatibleVectorType(type)
87  : type))
88  .getWidth();
89 }
90 
91 /// Creates `IntegerAttribute` with all bits set for given type
92 static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder) {
93  if (auto vecType = dyn_cast<VectorType>(type)) {
94  auto integerType = cast<IntegerType>(vecType.getElementType());
95  return builder.getIntegerAttr(integerType, -1);
96  }
97  auto integerType = cast<IntegerType>(type);
98  return builder.getIntegerAttr(integerType, -1);
99 }
100 
101 /// Creates `llvm.mlir.constant` with all bits set for the given type.
102 static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType,
103  PatternRewriter &rewriter) {
104  if (isa<VectorType>(srcType)) {
105  return rewriter.create<LLVM::ConstantOp>(
106  loc, dstType,
107  SplatElementsAttr::get(cast<ShapedType>(srcType),
108  minusOneIntegerAttribute(srcType, rewriter)));
109  }
110  return rewriter.create<LLVM::ConstantOp>(
111  loc, dstType, minusOneIntegerAttribute(srcType, rewriter));
112 }
113 
114 /// Creates `llvm.mlir.constant` with a floating-point scalar or vector value.
115 static Value createFPConstant(Location loc, Type srcType, Type dstType,
116  PatternRewriter &rewriter, double value) {
117  if (auto vecType = dyn_cast<VectorType>(srcType)) {
118  auto floatType = cast<FloatType>(vecType.getElementType());
119  return rewriter.create<LLVM::ConstantOp>(
120  loc, dstType,
121  SplatElementsAttr::get(vecType,
122  rewriter.getFloatAttr(floatType, value)));
123  }
124  auto floatType = cast<FloatType>(srcType);
125  return rewriter.create<LLVM::ConstantOp>(
126  loc, dstType, rewriter.getFloatAttr(floatType, value));
127 }
128 
129 /// Utility function for bitfield ops:
130 /// - `BitFieldInsert`
131 /// - `BitFieldSExtract`
132 /// - `BitFieldUExtract`
133 /// Truncates or extends the value. If the bitwidth of the value is the same as
134 /// `llvmType` bitwidth, the value remains unchanged.
136  Type llvmType,
137  PatternRewriter &rewriter) {
138  auto srcType = value.getType();
139  unsigned targetBitWidth = getLLVMTypeBitWidth(llvmType);
140  unsigned valueBitWidth = LLVM::isCompatibleType(srcType)
141  ? getLLVMTypeBitWidth(srcType)
142  : getBitWidth(srcType);
143 
144  if (valueBitWidth < targetBitWidth)
145  return rewriter.create<LLVM::ZExtOp>(loc, llvmType, value);
146  // If the bit widths of `Count` and `Offset` are greater than the bit width
147  // of the target type, they are truncated. Truncation is safe since `Count`
148  // and `Offset` must be no more than 64 for op behaviour to be defined. Hence,
149  // both values can be expressed in 8 bits.
150  if (valueBitWidth > targetBitWidth)
151  return rewriter.create<LLVM::TruncOp>(loc, llvmType, value);
152  return value;
153 }
154 
155 /// Broadcasts the value to vector with `numElements` number of elements.
156 static Value broadcast(Location loc, Value toBroadcast, unsigned numElements,
157  LLVMTypeConverter &typeConverter,
158  ConversionPatternRewriter &rewriter) {
159  auto vectorType = VectorType::get(numElements, toBroadcast.getType());
160  auto llvmVectorType = typeConverter.convertType(vectorType);
161  auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32));
162  Value broadcasted = rewriter.create<LLVM::UndefOp>(loc, llvmVectorType);
163  for (unsigned i = 0; i < numElements; ++i) {
164  auto index = rewriter.create<LLVM::ConstantOp>(
165  loc, llvmI32Type, rewriter.getI32IntegerAttr(i));
166  broadcasted = rewriter.create<LLVM::InsertElementOp>(
167  loc, llvmVectorType, broadcasted, toBroadcast, index);
168  }
169  return broadcasted;
170 }
171 
172 /// Broadcasts the value. If `srcType` is a scalar, the value remains unchanged.
173 static Value optionallyBroadcast(Location loc, Value value, Type srcType,
174  LLVMTypeConverter &typeConverter,
175  ConversionPatternRewriter &rewriter) {
176  if (auto vectorType = dyn_cast<VectorType>(srcType)) {
177  unsigned numElements = vectorType.getNumElements();
178  return broadcast(loc, value, numElements, typeConverter, rewriter);
179  }
180  return value;
181 }
182 
183 /// Utility function for bitfield ops: `BitFieldInsert`, `BitFieldSExtract` and
184 /// `BitFieldUExtract`.
185 /// Broadcast `Offset` and `Count` to match the type of `Base`. If `Base` is of
186 /// a vector type, construct a vector that has:
187 /// - same number of elements as `Base`
188 /// - each element has the type that is the same as the type of `Offset` or
189 /// `Count`
190 /// - each element has the same value as `Offset` or `Count`
191 /// Then cast `Offset` and `Count` if their bit width is different
192 /// from `Base` bit width.
193 static Value processCountOrOffset(Location loc, Value value, Type srcType,
194  Type dstType, LLVMTypeConverter &converter,
195  ConversionPatternRewriter &rewriter) {
196  Value broadcasted =
197  optionallyBroadcast(loc, value, srcType, converter, rewriter);
198  return optionallyTruncateOrExtend(loc, broadcasted, dstType, rewriter);
199 }
200 
201 /// Converts SPIR-V struct with a regular (according to `VulkanLayoutUtils`)
202 /// offset to LLVM struct. Otherwise, the conversion is not supported.
204  LLVMTypeConverter &converter) {
205  if (type != VulkanLayoutUtils::decorateType(type))
206  return nullptr;
207 
208  SmallVector<Type> elementsVector;
209  if (failed(converter.convertTypes(type.getElementTypes(), elementsVector)))
210  return nullptr;
211  return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
212  /*isPacked=*/false);
213 }
214 
215 /// Converts SPIR-V struct with no offset to packed LLVM struct.
217  LLVMTypeConverter &converter) {
218  SmallVector<Type> elementsVector;
219  if (failed(converter.convertTypes(type.getElementTypes(), elementsVector)))
220  return nullptr;
221  return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
222  /*isPacked=*/true);
223 }
224 
225 /// Creates LLVM dialect constant with the given value.
227  unsigned value) {
228  return rewriter.create<LLVM::ConstantOp>(
229  loc, IntegerType::get(rewriter.getContext(), 32),
230  rewriter.getIntegerAttr(rewriter.getI32Type(), value));
231 }
232 
233 /// Utility for `spirv.Load` and `spirv.Store` conversion.
234 static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands,
235  ConversionPatternRewriter &rewriter,
236  LLVMTypeConverter &typeConverter,
237  unsigned alignment, bool isVolatile,
238  bool isNonTemporal) {
239  if (auto loadOp = dyn_cast<spirv::LoadOp>(op)) {
240  auto dstType = typeConverter.convertType(loadOp.getType());
241  if (!dstType)
242  return rewriter.notifyMatchFailure(op, "type conversion failed");
243  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
244  loadOp, dstType, spirv::LoadOpAdaptor(operands).getPtr(), alignment,
245  isVolatile, isNonTemporal);
246  return success();
247  }
248  auto storeOp = cast<spirv::StoreOp>(op);
249  spirv::StoreOpAdaptor adaptor(operands);
250  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValue(),
251  adaptor.getPtr(), alignment,
252  isVolatile, isNonTemporal);
253  return success();
254 }
255 
256 //===----------------------------------------------------------------------===//
257 // Type conversion
258 //===----------------------------------------------------------------------===//
259 
260 /// Converts SPIR-V array type to LLVM array. Natural stride (according to
261 /// `VulkanLayoutUtils`) is also mapped to LLVM array. This has to be respected
262 /// when converting ops that manipulate array types.
263 static std::optional<Type> convertArrayType(spirv::ArrayType type,
264  TypeConverter &converter) {
265  unsigned stride = type.getArrayStride();
266  Type elementType = type.getElementType();
267  auto sizeInBytes = cast<spirv::SPIRVType>(elementType).getSizeInBytes();
268  if (stride != 0 && (!sizeInBytes || *sizeInBytes != stride))
269  return std::nullopt;
270 
271  auto llvmElementType = converter.convertType(elementType);
272  unsigned numElements = type.getNumElements();
273  return LLVM::LLVMArrayType::get(llvmElementType, numElements);
274 }
275 
276 static unsigned mapToOpenCLAddressSpace(spirv::StorageClass storageClass) {
277  // Based on
278  // https://registry.khronos.org/SPIR-V/specs/unified1/OpenCL.ExtendedInstructionSet.100.html#_binary_form
279  // and clang/lib/Basic/Targets/SPIR.h.
280  switch (storageClass) {
281 #define STORAGE_SPACE_MAP(storage, space) \
282  case spirv::StorageClass::storage: \
283  return space;
284  STORAGE_SPACE_MAP(Function, 0)
285  STORAGE_SPACE_MAP(CrossWorkgroup, 1)
286  STORAGE_SPACE_MAP(Input, 1)
287  STORAGE_SPACE_MAP(UniformConstant, 2)
288  STORAGE_SPACE_MAP(Workgroup, 3)
289  STORAGE_SPACE_MAP(Generic, 4)
290  STORAGE_SPACE_MAP(DeviceOnlyINTEL, 5)
291  STORAGE_SPACE_MAP(HostOnlyINTEL, 6)
292 #undef STORAGE_SPACE_MAP
293  default:
294  return defaultAddressSpace;
295  }
296 }
297 
298 static unsigned mapToAddressSpace(spirv::ClientAPI clientAPI,
299  spirv::StorageClass storageClass) {
300  switch (clientAPI) {
301 #define CLIENT_MAP(client, storage) \
302  case spirv::ClientAPI::client: \
303  return mapTo##client##AddressSpace(storage);
304  CLIENT_MAP(OpenCL, storageClass)
305 #undef CLIENT_MAP
306  default:
307  return defaultAddressSpace;
308  }
309 }
310 
311 /// Converts SPIR-V pointer type to LLVM pointer. Pointer's storage class is not
312 /// modelled at the moment.
314  LLVMTypeConverter &converter,
315  spirv::ClientAPI clientAPI) {
316  unsigned addressSpace = mapToAddressSpace(clientAPI, type.getStorageClass());
317  return LLVM::LLVMPointerType::get(type.getContext(), addressSpace);
318 }
319 
320 /// Converts SPIR-V runtime array to LLVM array. Since LLVM allows indexing over
321 /// the bounds, the runtime array is converted to a 0-sized LLVM array. There is
322 /// no modelling of array stride at the moment.
323 static std::optional<Type> convertRuntimeArrayType(spirv::RuntimeArrayType type,
324  TypeConverter &converter) {
325  if (type.getArrayStride() != 0)
326  return std::nullopt;
327  auto elementType = converter.convertType(type.getElementType());
328  return LLVM::LLVMArrayType::get(elementType, 0);
329 }
330 
331 /// Converts SPIR-V struct to LLVM struct. There is no support of structs with
332 /// member decorations. Also, only natural offset is supported.
334  LLVMTypeConverter &converter) {
336  type.getMemberDecorations(memberDecorations);
337  if (!memberDecorations.empty())
338  return nullptr;
339  if (type.hasOffset())
340  return convertStructTypeWithOffset(type, converter);
341  return convertStructTypePacked(type, converter);
342 }
343 
344 //===----------------------------------------------------------------------===//
345 // Operation conversion
346 //===----------------------------------------------------------------------===//
347 
348 namespace {
349 
350 class AccessChainPattern : public SPIRVToLLVMConversion<spirv::AccessChainOp> {
351 public:
353 
354  LogicalResult
355  matchAndRewrite(spirv::AccessChainOp op, OpAdaptor adaptor,
356  ConversionPatternRewriter &rewriter) const override {
357  auto dstType = typeConverter.convertType(op.getComponentPtr().getType());
358  if (!dstType)
359  return rewriter.notifyMatchFailure(op, "type conversion failed");
360  // To use GEP we need to add a first 0 index to go through the pointer.
361  auto indices = llvm::to_vector<4>(adaptor.getIndices());
362  Type indexType = op.getIndices().front().getType();
363  auto llvmIndexType = typeConverter.convertType(indexType);
364  if (!llvmIndexType)
365  return rewriter.notifyMatchFailure(op, "type conversion failed");
366  Value zero = rewriter.create<LLVM::ConstantOp>(
367  op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0));
368  indices.insert(indices.begin(), zero);
369 
370  auto elementType = typeConverter.convertType(
371  cast<spirv::PointerType>(op.getBasePtr().getType()).getPointeeType());
372  if (!elementType)
373  return rewriter.notifyMatchFailure(op, "type conversion failed");
374  rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, dstType, elementType,
375  adaptor.getBasePtr(), indices);
376  return success();
377  }
378 };
379 
380 class AddressOfPattern : public SPIRVToLLVMConversion<spirv::AddressOfOp> {
381 public:
383 
384  LogicalResult
385  matchAndRewrite(spirv::AddressOfOp op, OpAdaptor adaptor,
386  ConversionPatternRewriter &rewriter) const override {
387  auto dstType = typeConverter.convertType(op.getPointer().getType());
388  if (!dstType)
389  return rewriter.notifyMatchFailure(op, "type conversion failed");
390  rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, dstType,
391  op.getVariable());
392  return success();
393  }
394 };
395 
396 class BitFieldInsertPattern
397  : public SPIRVToLLVMConversion<spirv::BitFieldInsertOp> {
398 public:
400 
401  LogicalResult
402  matchAndRewrite(spirv::BitFieldInsertOp op, OpAdaptor adaptor,
403  ConversionPatternRewriter &rewriter) const override {
404  auto srcType = op.getType();
405  auto dstType = typeConverter.convertType(srcType);
406  if (!dstType)
407  return rewriter.notifyMatchFailure(op, "type conversion failed");
408  Location loc = op.getLoc();
409 
410  // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
411  Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
412  typeConverter, rewriter);
413  Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
414  typeConverter, rewriter);
415 
416  // Create a mask with bits set outside [Offset, Offset + Count - 1].
417  Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
418  Value maskShiftedByCount =
419  rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count);
420  Value negated = rewriter.create<LLVM::XOrOp>(loc, dstType,
421  maskShiftedByCount, minusOne);
422  Value maskShiftedByCountAndOffset =
423  rewriter.create<LLVM::ShlOp>(loc, dstType, negated, offset);
424  Value mask = rewriter.create<LLVM::XOrOp>(
425  loc, dstType, maskShiftedByCountAndOffset, minusOne);
426 
427  // Extract unchanged bits from the `Base` that are outside of
428  // [Offset, Offset + Count - 1]. Then `or` with shifted `Insert`.
429  Value baseAndMask =
430  rewriter.create<LLVM::AndOp>(loc, dstType, op.getBase(), mask);
431  Value insertShiftedByOffset =
432  rewriter.create<LLVM::ShlOp>(loc, dstType, op.getInsert(), offset);
433  rewriter.replaceOpWithNewOp<LLVM::OrOp>(op, dstType, baseAndMask,
434  insertShiftedByOffset);
435  return success();
436  }
437 };
438 
439 /// Converts SPIR-V ConstantOp with scalar or vector type.
440 class ConstantScalarAndVectorPattern
441  : public SPIRVToLLVMConversion<spirv::ConstantOp> {
442 public:
444 
445  LogicalResult
446  matchAndRewrite(spirv::ConstantOp constOp, OpAdaptor adaptor,
447  ConversionPatternRewriter &rewriter) const override {
448  auto srcType = constOp.getType();
449  if (!isa<VectorType>(srcType) && !srcType.isIntOrFloat())
450  return failure();
451 
452  auto dstType = typeConverter.convertType(srcType);
453  if (!dstType)
454  return rewriter.notifyMatchFailure(constOp, "type conversion failed");
455 
456  // SPIR-V constant can be a signed/unsigned integer, which has to be
457  // casted to signless integer when converting to LLVM dialect. Removing the
458  // sign bit may have unexpected behaviour. However, it is better to handle
459  // it case-by-case, given that the purpose of the conversion is not to
460  // cover all possible corner cases.
461  if (isSignedIntegerOrVector(srcType) ||
462  isUnsignedIntegerOrVector(srcType)) {
463  auto signlessType = rewriter.getIntegerType(getBitWidth(srcType));
464 
465  if (isa<VectorType>(srcType)) {
466  auto dstElementsAttr = cast<DenseIntElementsAttr>(constOp.getValue());
467  rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
468  constOp, dstType,
469  dstElementsAttr.mapValues(
470  signlessType, [&](const APInt &value) { return value; }));
471  return success();
472  }
473  auto srcAttr = cast<IntegerAttr>(constOp.getValue());
474  auto dstAttr = rewriter.getIntegerAttr(signlessType, srcAttr.getValue());
475  rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, dstAttr);
476  return success();
477  }
478  rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
479  constOp, dstType, adaptor.getOperands(), constOp->getAttrs());
480  return success();
481  }
482 };
483 
484 class BitFieldSExtractPattern
485  : public SPIRVToLLVMConversion<spirv::BitFieldSExtractOp> {
486 public:
488 
489  LogicalResult
490  matchAndRewrite(spirv::BitFieldSExtractOp op, OpAdaptor adaptor,
491  ConversionPatternRewriter &rewriter) const override {
492  auto srcType = op.getType();
493  auto dstType = typeConverter.convertType(srcType);
494  if (!dstType)
495  return rewriter.notifyMatchFailure(op, "type conversion failed");
496  Location loc = op.getLoc();
497 
498  // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
499  Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
500  typeConverter, rewriter);
501  Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
502  typeConverter, rewriter);
503 
504  // Create a constant that holds the size of the `Base`.
505  IntegerType integerType;
506  if (auto vecType = dyn_cast<VectorType>(srcType))
507  integerType = cast<IntegerType>(vecType.getElementType());
508  else
509  integerType = cast<IntegerType>(srcType);
510 
511  auto baseSize = rewriter.getIntegerAttr(integerType, getBitWidth(srcType));
512  Value size =
513  isa<VectorType>(srcType)
514  ? rewriter.create<LLVM::ConstantOp>(
515  loc, dstType,
516  SplatElementsAttr::get(cast<ShapedType>(srcType), baseSize))
517  : rewriter.create<LLVM::ConstantOp>(loc, dstType, baseSize);
518 
519  // Shift `Base` left by [sizeof(Base) - (Count + Offset)], so that the bit
520  // at Offset + Count - 1 is the most significant bit now.
521  Value countPlusOffset =
522  rewriter.create<LLVM::AddOp>(loc, dstType, count, offset);
523  Value amountToShiftLeft =
524  rewriter.create<LLVM::SubOp>(loc, dstType, size, countPlusOffset);
525  Value baseShiftedLeft = rewriter.create<LLVM::ShlOp>(
526  loc, dstType, op.getBase(), amountToShiftLeft);
527 
528  // Shift the result right, filling the bits with the sign bit.
529  Value amountToShiftRight =
530  rewriter.create<LLVM::AddOp>(loc, dstType, offset, amountToShiftLeft);
531  rewriter.replaceOpWithNewOp<LLVM::AShrOp>(op, dstType, baseShiftedLeft,
532  amountToShiftRight);
533  return success();
534  }
535 };
536 
537 class BitFieldUExtractPattern
538  : public SPIRVToLLVMConversion<spirv::BitFieldUExtractOp> {
539 public:
541 
542  LogicalResult
543  matchAndRewrite(spirv::BitFieldUExtractOp op, OpAdaptor adaptor,
544  ConversionPatternRewriter &rewriter) const override {
545  auto srcType = op.getType();
546  auto dstType = typeConverter.convertType(srcType);
547  if (!dstType)
548  return rewriter.notifyMatchFailure(op, "type conversion failed");
549  Location loc = op.getLoc();
550 
551  // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
552  Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
553  typeConverter, rewriter);
554  Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
555  typeConverter, rewriter);
556 
557  // Create a mask with bits set at [0, Count - 1].
558  Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
559  Value maskShiftedByCount =
560  rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count);
561  Value mask = rewriter.create<LLVM::XOrOp>(loc, dstType, maskShiftedByCount,
562  minusOne);
563 
564  // Shift `Base` by `Offset` and apply the mask on it.
565  Value shiftedBase =
566  rewriter.create<LLVM::LShrOp>(loc, dstType, op.getBase(), offset);
567  rewriter.replaceOpWithNewOp<LLVM::AndOp>(op, dstType, shiftedBase, mask);
568  return success();
569  }
570 };
571 
572 class BranchConversionPattern : public SPIRVToLLVMConversion<spirv::BranchOp> {
573 public:
575 
576  LogicalResult
577  matchAndRewrite(spirv::BranchOp branchOp, OpAdaptor adaptor,
578  ConversionPatternRewriter &rewriter) const override {
579  rewriter.replaceOpWithNewOp<LLVM::BrOp>(branchOp, adaptor.getOperands(),
580  branchOp.getTarget());
581  return success();
582  }
583 };
584 
585 class BranchConditionalConversionPattern
586  : public SPIRVToLLVMConversion<spirv::BranchConditionalOp> {
587 public:
588  using SPIRVToLLVMConversion<
589  spirv::BranchConditionalOp>::SPIRVToLLVMConversion;
590 
591  LogicalResult
592  matchAndRewrite(spirv::BranchConditionalOp op, OpAdaptor adaptor,
593  ConversionPatternRewriter &rewriter) const override {
594  // If branch weights exist, map them to 32-bit integer vector.
595  DenseI32ArrayAttr branchWeights = nullptr;
596  if (auto weights = op.getBranchWeights()) {
597  SmallVector<int32_t> weightValues;
598  for (auto weight : weights->getAsRange<IntegerAttr>())
599  weightValues.push_back(weight.getInt());
600  branchWeights = DenseI32ArrayAttr::get(getContext(), weightValues);
601  }
602 
603  rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
604  op, op.getCondition(), op.getTrueBlockArguments(),
605  op.getFalseBlockArguments(), branchWeights, op.getTrueBlock(),
606  op.getFalseBlock());
607  return success();
608  }
609 };
610 
611 /// Converts `spirv.getCompositeExtract` to `llvm.extractvalue` if the container
612 /// type is an aggregate type (struct or array). Otherwise, converts to
613 /// `llvm.extractelement` that operates on vectors.
614 class CompositeExtractPattern
615  : public SPIRVToLLVMConversion<spirv::CompositeExtractOp> {
616 public:
618 
619  LogicalResult
620  matchAndRewrite(spirv::CompositeExtractOp op, OpAdaptor adaptor,
621  ConversionPatternRewriter &rewriter) const override {
622  auto dstType = this->typeConverter.convertType(op.getType());
623  if (!dstType)
624  return rewriter.notifyMatchFailure(op, "type conversion failed");
625 
626  Type containerType = op.getComposite().getType();
627  if (isa<VectorType>(containerType)) {
628  Location loc = op.getLoc();
629  IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]);
630  Value index = createI32ConstantOf(loc, rewriter, value.getInt());
631  rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
632  op, dstType, adaptor.getComposite(), index);
633  return success();
634  }
635 
636  rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(
637  op, adaptor.getComposite(),
638  LLVM::convertArrayToIndices(op.getIndices()));
639  return success();
640  }
641 };
642 
643 /// Converts `spirv.getCompositeInsert` to `llvm.insertvalue` if the container
644 /// type is an aggregate type (struct or array). Otherwise, converts to
645 /// `llvm.insertelement` that operates on vectors.
646 class CompositeInsertPattern
647  : public SPIRVToLLVMConversion<spirv::CompositeInsertOp> {
648 public:
650 
651  LogicalResult
652  matchAndRewrite(spirv::CompositeInsertOp op, OpAdaptor adaptor,
653  ConversionPatternRewriter &rewriter) const override {
654  auto dstType = this->typeConverter.convertType(op.getType());
655  if (!dstType)
656  return rewriter.notifyMatchFailure(op, "type conversion failed");
657 
658  Type containerType = op.getComposite().getType();
659  if (isa<VectorType>(containerType)) {
660  Location loc = op.getLoc();
661  IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]);
662  Value index = createI32ConstantOf(loc, rewriter, value.getInt());
663  rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
664  op, dstType, adaptor.getComposite(), adaptor.getObject(), index);
665  return success();
666  }
667 
668  rewriter.replaceOpWithNewOp<LLVM::InsertValueOp>(
669  op, adaptor.getComposite(), adaptor.getObject(),
670  LLVM::convertArrayToIndices(op.getIndices()));
671  return success();
672  }
673 };
674 
675 /// Converts SPIR-V operations that have straightforward LLVM equivalent
676 /// into LLVM dialect operations.
677 template <typename SPIRVOp, typename LLVMOp>
678 class DirectConversionPattern : public SPIRVToLLVMConversion<SPIRVOp> {
679 public:
681 
682  LogicalResult
683  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
684  ConversionPatternRewriter &rewriter) const override {
685  auto dstType = this->typeConverter.convertType(op.getType());
686  if (!dstType)
687  return rewriter.notifyMatchFailure(op, "type conversion failed");
688  rewriter.template replaceOpWithNewOp<LLVMOp>(
689  op, dstType, adaptor.getOperands(), op->getAttrs());
690  return success();
691  }
692 };
693 
694 /// Converts `spirv.ExecutionMode` into a global struct constant that holds
695 /// execution mode information.
696 class ExecutionModePattern
697  : public SPIRVToLLVMConversion<spirv::ExecutionModeOp> {
698 public:
700 
701  LogicalResult
702  matchAndRewrite(spirv::ExecutionModeOp op, OpAdaptor adaptor,
703  ConversionPatternRewriter &rewriter) const override {
704  // First, create the global struct's name that would be associated with
705  // this entry point's execution mode. We set it to be:
706  // __spv__{SPIR-V module name}_{function name}_execution_mode_info_{mode}
707  ModuleOp module = op->getParentOfType<ModuleOp>();
708  spirv::ExecutionModeAttr executionModeAttr = op.getExecutionModeAttr();
709  std::string moduleName;
710  if (module.getName().has_value())
711  moduleName = "_" + module.getName()->str();
712  else
713  moduleName = "";
714  std::string executionModeInfoName = llvm::formatv(
715  "__spv_{0}_{1}_execution_mode_info_{2}", moduleName, op.getFn().str(),
716  static_cast<uint32_t>(executionModeAttr.getValue()));
717 
718  MLIRContext *context = rewriter.getContext();
719  OpBuilder::InsertionGuard guard(rewriter);
720  rewriter.setInsertionPointToStart(module.getBody());
721 
722  // Create a struct type, corresponding to the C struct below.
723  // struct {
724  // int32_t executionMode;
725  // int32_t values[]; // optional values
726  // };
727  auto llvmI32Type = IntegerType::get(context, 32);
728  SmallVector<Type, 2> fields;
729  fields.push_back(llvmI32Type);
730  ArrayAttr values = op.getValues();
731  if (!values.empty()) {
732  auto arrayType = LLVM::LLVMArrayType::get(llvmI32Type, values.size());
733  fields.push_back(arrayType);
734  }
735  auto structType = LLVM::LLVMStructType::getLiteral(context, fields);
736 
737  // Create `llvm.mlir.global` with initializer region containing one block.
738  auto global = rewriter.create<LLVM::GlobalOp>(
739  UnknownLoc::get(context), structType, /*isConstant=*/true,
740  LLVM::Linkage::External, executionModeInfoName, Attribute(),
741  /*alignment=*/0);
742  Location loc = global.getLoc();
743  Region &region = global.getInitializerRegion();
744  Block *block = rewriter.createBlock(&region);
745 
746  // Initialize the struct and set the execution mode value.
747  rewriter.setInsertionPoint(block, block->begin());
748  Value structValue = rewriter.create<LLVM::UndefOp>(loc, structType);
749  Value executionMode = rewriter.create<LLVM::ConstantOp>(
750  loc, llvmI32Type,
751  rewriter.getI32IntegerAttr(
752  static_cast<uint32_t>(executionModeAttr.getValue())));
753  structValue = rewriter.create<LLVM::InsertValueOp>(loc, structValue,
754  executionMode, 0);
755 
756  // Insert extra operands if they exist into execution mode info struct.
757  for (unsigned i = 0, e = values.size(); i < e; ++i) {
758  auto attr = values.getValue()[i];
759  Value entry = rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, attr);
760  structValue = rewriter.create<LLVM::InsertValueOp>(
761  loc, structValue, entry, ArrayRef<int64_t>({1, i}));
762  }
763  rewriter.create<LLVM::ReturnOp>(loc, ArrayRef<Value>({structValue}));
764  rewriter.eraseOp(op);
765  return success();
766  }
767 };
768 
769 /// Converts `spirv.GlobalVariable` to `llvm.mlir.global`. Note that SPIR-V
770 /// global returns a pointer, whereas in LLVM dialect the global holds an actual
771 /// value. This difference is handled by `spirv.mlir.addressof` and
772 /// `llvm.mlir.addressof` ops that both return a pointer.
///
/// The pattern fails to match when the variable has an initializer, when the
/// pointee type cannot be converted, or when the storage class is not one of
/// `Input`, `Private`, `Output`, `StorageBuffer` or `UniformConstant`.
773 class GlobalVariablePattern
774  : public SPIRVToLLVMConversion<spirv::GlobalVariableOp> {
775 public:
     /// Remembers `clientAPI` and forwards every remaining argument to the base
     /// conversion pattern constructor.
776  template <typename... Args>
777  GlobalVariablePattern(spirv::ClientAPI clientAPI, Args &&...args)
778  : SPIRVToLLVMConversion<spirv::GlobalVariableOp>(
779  std::forward<Args>(args)...),
780  clientAPI(clientAPI) {}
781 
782  LogicalResult
783  matchAndRewrite(spirv::GlobalVariableOp op, OpAdaptor adaptor,
784  ConversionPatternRewriter &rewriter) const override {
785  // Currently, there is no support of initialization with a constant value in
786  // SPIR-V dialect. Specialization constants are not considered as well.
787  if (op.getInitializer())
788  return failure();
789 
     // The SPIR-V op is typed as a pointer; the LLVM global is typed as the
     // converted pointee.
790  auto srcType = cast<spirv::PointerType>(op.getType());
791  auto dstType = typeConverter.convertType(srcType.getPointeeType());
792  if (!dstType)
793  return rewriter.notifyMatchFailure(op, "type conversion failed");
794 
795  // Limit conversion to the current invocation only or `StorageBuffer`
796  // required by SPIR-V runner.
797  // This is okay because multiple invocations are not supported yet.
798  auto storageClass = srcType.getStorageClass();
799  switch (storageClass) {
800  case spirv::StorageClass::Input:
801  case spirv::StorageClass::Private:
802  case spirv::StorageClass::Output:
803  case spirv::StorageClass::StorageBuffer:
804  case spirv::StorageClass::UniformConstant:
805  break;
806  default:
807  return failure();
808  }
809 
810  // LLVM dialect spec: "If the global value is a constant, storing into it is
811  // not allowed.". This corresponds to SPIR-V 'Input' and 'UniformConstant'
812  // storage class that is read-only.
813  bool isConstant = (storageClass == spirv::StorageClass::Input) ||
814  (storageClass == spirv::StorageClass::UniformConstant);
815  // SPIR-V spec: "By default, functions and global variables are private to a
816  // module and cannot be accessed by other modules. However, a module may be
817  // written to export or import functions and global (module scope)
818  // variables.". Therefore, map 'Private' storage class to private linkage,
819  // and every other supported storage class to external linkage.
820  auto linkage = storageClass == spirv::StorageClass::Private
821  ? LLVM::Linkage::Private
822  : LLVM::Linkage::External;
     // The replacement global reuses the SPIR-V symbol name, is given a null
     // `Attribute()` and alignment 0, and lives in the address space returned
     // by `mapToAddressSpace` for this client API and storage class.
823  auto newGlobalOp = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
824  op, dstType, isConstant, linkage, op.getSymName(), Attribute(),
825  /*alignment=*/0, mapToAddressSpace(clientAPI, storageClass));
826 
827  // Attach location attribute if applicable
828  if (op.getLocationAttr())
829  newGlobalOp->setAttr(op.getLocationAttrName(), op.getLocationAttr());
830 
831  return success();
832  }
833 
834 private:
     /// Client API passed at construction; used only to pick the address space.
835  spirv::ClientAPI clientAPI;
836 };
837 
838 /// Converts SPIR-V cast ops that do not have straightforward LLVM
839 /// equivalent in LLVM dialect.
///
/// The op is replaced with `LLVMExtOp` when the operand's bit width is smaller
/// than the result's, and with `LLVMTruncOp` when it is larger. Equal bit
/// widths are not handled and make the pattern fail.
840 template <typename SPIRVOp, typename LLVMExtOp, typename LLVMTruncOp>
841 class IndirectCastPattern : public SPIRVToLLVMConversion<SPIRVOp> {
842 public:
844 
845  LogicalResult
846  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
847  ConversionPatternRewriter &rewriter) const override {
848 
     // Bit widths are compared on the original SPIR-V types.
849  Type fromType = op.getOperand().getType();
850  Type toType = op.getType();
851 
852  auto dstType = this->typeConverter.convertType(toType);
853  if (!dstType)
854  return rewriter.notifyMatchFailure(op, "type conversion failed");
855 
     // Widening cast.
856  if (getBitWidth(fromType) < getBitWidth(toType)) {
857  rewriter.template replaceOpWithNewOp<LLVMExtOp>(op, dstType,
858  adaptor.getOperands());
859  return success();
860  }
     // Narrowing cast.
861  if (getBitWidth(fromType) > getBitWidth(toType)) {
862  rewriter.template replaceOpWithNewOp<LLVMTruncOp>(op, dstType,
863  adaptor.getOperands());
864  return success();
865  }
     // Same bit width on both sides: neither an extension nor a truncation.
866  return failure();
867  }
868 };
869 
/// Converts `spirv.FunctionCall` to `llvm.call`, forwarding the converted
/// operands and all attributes of the original op. A call without results is
/// created with no result type; otherwise the type of result 0 is converted
/// and used as the single result type.
870 class FunctionCallPattern
871  : public SPIRVToLLVMConversion<spirv::FunctionCallOp> {
872 public:
874 
875  LogicalResult
876  matchAndRewrite(spirv::FunctionCallOp callOp, OpAdaptor adaptor,
877  ConversionPatternRewriter &rewriter) const override {
     // Function returns nothing.
878  if (callOp.getNumResults() == 0) {
879  rewriter.replaceOpWithNewOp<LLVM::CallOp>(
880  callOp, std::nullopt, adaptor.getOperands(), callOp->getAttrs());
881  return success();
882  }
883 
884  // Function returns a single result.
885  auto dstType = typeConverter.convertType(callOp.getType(0));
886  if (!dstType)
887  return rewriter.notifyMatchFailure(callOp, "type conversion failed");
888  rewriter.replaceOpWithNewOp<LLVM::CallOp>(
889  callOp, dstType, adaptor.getOperands(), callOp->getAttrs());
890  return success();
891  }
892 };
893 
894 /// Converts SPIR-V floating-point comparisons to llvm.fcmp "predicate"
///
/// `SPIRVOp` is the SPIR-V comparison op and `predicate` is the fixed
/// `llvm.fcmp` predicate it maps to. The result type is the converted type of
/// the SPIR-V op's result; the pattern fails if that conversion fails.
895 template <typename SPIRVOp, LLVM::FCmpPredicate predicate>
896 class FComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
897 public:
899 
900  LogicalResult
901  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
902  ConversionPatternRewriter &rewriter) const override {
903 
904  auto dstType = this->typeConverter.convertType(op.getType());
905  if (!dstType)
906  return rewriter.notifyMatchFailure(op, "type conversion failed");
907 
     // Build the comparison from the adaptor's (already converted) operands,
     // not from the original op's operands, so the new LLVM op never consumes
     // values that still carry unconverted SPIR-V types.
908  rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>(
909  op, dstType, predicate, adaptor.getOperand1(), adaptor.getOperand2());
910  return success();
911  }
912 };
913 
914 /// Converts SPIR-V integer comparisons to llvm.icmp "predicate"
///
/// `SPIRVOp` is the SPIR-V comparison op and `predicate` is the fixed
/// `llvm.icmp` predicate it maps to. The result type is the converted type of
/// the SPIR-V op's result; the pattern fails if that conversion fails.
915 template <typename SPIRVOp, LLVM::ICmpPredicate predicate>
916 class IComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
917 public:
919 
920  LogicalResult
921  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
922  ConversionPatternRewriter &rewriter) const override {
923 
924  auto dstType = this->typeConverter.convertType(op.getType());
925  if (!dstType)
926  return rewriter.notifyMatchFailure(op, "type conversion failed");
927 
     // Build the comparison from the adaptor's (already converted) operands,
     // not from the original op's operands, so the new LLVM op never consumes
     // values that still carry unconverted SPIR-V types.
928  rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>(
929  op, dstType, predicate, adaptor.getOperand1(), adaptor.getOperand2());
930  return success();
931  }
932 };
933 
/// Converts `spirv.GL.InverseSqrt` to `1.0 / llvm.sqrt(x)`: a floating-point
/// constant one (built by `createFPConstant`) divided by the square root of
/// the operand.
934 class InverseSqrtPattern
935  : public SPIRVToLLVMConversion<spirv::GLInverseSqrtOp> {
936 public:
938 
939  LogicalResult
940  matchAndRewrite(spirv::GLInverseSqrtOp op, OpAdaptor adaptor,
941  ConversionPatternRewriter &rewriter) const override {
942  auto srcType = op.getType();
943  auto dstType = typeConverter.convertType(srcType);
944  if (!dstType)
945  return rewriter.notifyMatchFailure(op, "type conversion failed");
946 
947  Location loc = op.getLoc();
948  Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
     // Take the operand from the adaptor so the converted value is used.
949  Value sqrt = rewriter.create<LLVM::SqrtOp>(loc, dstType, adaptor.getOperand());
950  rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, dstType, one, sqrt);
951  return success();
952  }
953 };
954 
955 /// Converts `spirv.Load` and `spirv.Store` to LLVM dialect.
///
/// The actual replacement is delegated to `replaceWithLoadOrStore`; this
/// pattern only derives the alignment, volatile and non-temporal settings from
/// the op's optional memory access attribute. Memory access values other than
/// the cases listed in the switch make the pattern fail.
956 template <typename SPIRVOp>
957 class LoadStorePattern : public SPIRVToLLVMConversion<SPIRVOp> {
958 public:
960 
961  LogicalResult
962  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
963  ConversionPatternRewriter &rewriter) const override {
     // No memory access attribute: alignment 0, not volatile, not non-temporal.
964  if (!op.getMemoryAccess()) {
965  return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter,
966  this->typeConverter, /*alignment=*/0,
967  /*isVolatile=*/false,
968  /*isNonTemporal=*/false);
969  }
970  auto memoryAccess = *op.getMemoryAccess();
971  switch (memoryAccess) {
     // NOTE(review): listing line 973 is not shown here; presumably it holds
     // one more case label that falls through to the block below - confirm.
972  case spirv::MemoryAccess::Aligned:
974  case spirv::MemoryAccess::Nontemporal:
975  case spirv::MemoryAccess::Volatile: {
     // The alignment is read only for `Aligned`; each flag is set only when
     // the attribute is exactly that single memory access value.
976  unsigned alignment =
977  memoryAccess == spirv::MemoryAccess::Aligned ? *op.getAlignment() : 0;
978  bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal;
979  bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile;
980  return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter,
981  this->typeConverter, alignment, isVolatile,
982  isNonTemporal);
983  }
984  default:
985  // There is no support of other memory access attributes.
986  return failure();
987  }
988  }
989 };
990 
991 /// Converts `spirv.Not` and `spirv.LogicalNot` into LLVM dialect.
///
/// The op is replaced with `llvm.xor` of the operand and a mask constant built
/// from `minusOneIntegerAttribute`; for vector types the mask is a splat of
/// that attribute.
992 template <typename SPIRVOp>
993 class NotPattern : public SPIRVToLLVMConversion<SPIRVOp> {
994 public:
996 
997  LogicalResult
998  matchAndRewrite(SPIRVOp notOp, typename SPIRVOp::Adaptor adaptor,
999  ConversionPatternRewriter &rewriter) const override {
1000  auto srcType = notOp.getType();
1001  auto dstType = this->typeConverter.convertType(srcType);
1002  if (!dstType)
1003  return rewriter.notifyMatchFailure(notOp, "type conversion failed");
1004 
1005  Location loc = notOp.getLoc();
1006  IntegerAttr minusOne = minusOneIntegerAttribute(srcType, rewriter);
1007  auto mask =
1008  isa<VectorType>(srcType)
1009  ? rewriter.create<LLVM::ConstantOp>(
1010  loc, dstType,
1011  SplatElementsAttr::get(cast<VectorType>(srcType), minusOne))
1012  : rewriter.create<LLVM::ConstantOp>(loc, dstType, minusOne);
      // XOR the mask with the adaptor's (already converted) operand rather
      // than with the original SPIR-V operand.
1013  rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(notOp, dstType,
1014  adaptor.getOperand(), mask);
1015  return success();
1016  }
1017 };
1018 
1019 /// A template pattern that erases the given `SPIRVOp`.
///
/// The op is erased without creating any replacement, and the pattern always
/// reports success.
1020 template <typename SPIRVOp>
1021 class ErasePattern : public SPIRVToLLVMConversion<SPIRVOp> {
1022 public:
1024 
1025  LogicalResult
1026  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
1027  ConversionPatternRewriter &rewriter) const override {
1028  rewriter.eraseOp(op);
1029  return success();
1030  }
1031 };
1032 
/// Converts `spirv.Return` to an `llvm.return` that has no result types and no
/// operands. Always succeeds.
1033 class ReturnPattern : public SPIRVToLLVMConversion<spirv::ReturnOp> {
1034 public:
1036 
1037  LogicalResult
1038  matchAndRewrite(spirv::ReturnOp returnOp, OpAdaptor adaptor,
1039  ConversionPatternRewriter &rewriter) const override {
1040  rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp, ArrayRef<Type>(),
1041  ArrayRef<Value>());
1042  return success();
1043  }
1044 };
1045 
/// Converts `spirv.ReturnValue` to an `llvm.return` that carries the converted
/// operands supplied by the adaptor. Always succeeds.
1046 class ReturnValuePattern : public SPIRVToLLVMConversion<spirv::ReturnValueOp> {
1047 public:
1049 
1050  LogicalResult
1051  matchAndRewrite(spirv::ReturnValueOp returnValueOp, OpAdaptor adaptor,
1052  ConversionPatternRewriter &rewriter) const override {
1053  rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnValueOp, ArrayRef<Type>(),
1054  adaptor.getOperands());
1055  return success();
1056  }
1057 };
1058 
1059 /// Converts `spirv.mlir.loop` to LLVM dialect. All blocks within the loop
1060 /// should be reachable for conversion to succeed. The structure of the loop in
1061 /// LLVM dialect will be the following:
1062 ///
1063 ///            +------------------------------------+
1064 ///            | <code before spirv.mlir.loop>      |
1065 ///            | llvm.br ^header                    |
1066 ///            +------------------------------------+
1067 ///                              |
1068 ///   +----------------+         |
1069 ///   |                |         |
1070 ///   |                V         V
1071 ///   |  +------------------------------------+
1072 ///   |  | ^header:                           |
1073 ///   |  |   <header code>                    |
1074 ///   |  |   llvm.cond_br %cond, ^body, ^exit |
1075 ///   |  +------------------------------------+
1076 ///   |                    |
1077 ///   |                    |----------------------+
1078 ///   |                    |                      |
1079 ///   |                    V                      |
1080 ///   |  +------------------------------------+   |
1081 ///   |  | ^body:                             |   |
1082 ///   |  |   <body code>                      |   |
1083 ///   |  |   llvm.br ^continue                |   |
1084 ///   |  +------------------------------------+   |
1085 ///   |                    |                      |
1086 ///   |                    V                      |
1087 ///   |  +------------------------------------+   |
1088 ///   |  | ^continue:                         |   |
1089 ///   |  |   <continue code>                  |   |
1090 ///   |  |   llvm.br ^header                  |   |
1091 ///   |  +------------------------------------+   |
1092 ///   |               |                           |
1093 ///   +---------------+    +----------------------+
1094 ///                        |
1095 ///                        V
1096 ///      +------------------------------------+
1097 ///      | ^exit:                             |
1098 ///      |   llvm.br ^remaining               |
1099 ///      +------------------------------------+
1100 ///                        |
1101 ///                        V
1102 ///      +------------------------------------+
1103 ///      | ^remaining:                        |
1104 ///      |   <code after spirv.mlir.loop>     |
1105 ///      +------------------------------------+
1106 ///
1107 class LoopPattern : public SPIRVToLLVMConversion<spirv::LoopOp> {
1108 public:
1110 
1111  LogicalResult
1112  matchAndRewrite(spirv::LoopOp loopOp, OpAdaptor adaptor,
1113  ConversionPatternRewriter &rewriter) const override {
1114  // There is no support of loop control at the moment.
1115  if (loopOp.getLoopControl() != spirv::LoopControl::None)
1116  return failure();
1117 
1118  Location loc = loopOp.getLoc();
1119 
1120  // Split the current block after `spirv.mlir.loop`. The remaining ops will
1121  // be used in `endBlock`.
1122  Block *currentBlock = rewriter.getBlock();
1123  auto position = Block::iterator(loopOp);
1124  Block *endBlock = rewriter.splitBlock(currentBlock, position);
1125 
1126  // Remove entry block and create a branch in the current block going to the
1127  // header block.
      // The entry block is asserted to hold exactly one op; if that op is not
      // a `spirv.Branch` the pattern fails. NOTE(review): this failure is
      // returned after the block split above - confirm the driver rolls the
      // split back.
1128  Block *entryBlock = loopOp.getEntryBlock();
1129  assert(entryBlock->getOperations().size() == 1);
1130  auto brOp = dyn_cast<spirv::BranchOp>(entryBlock->getOperations().front());
1131  if (!brOp)
1132  return failure();
1133  Block *headerBlock = loopOp.getHeaderBlock();
1134  rewriter.setInsertionPointToEnd(currentBlock);
1135  rewriter.create<LLVM::BrOp>(loc, brOp.getBlockArguments(), headerBlock);
1136  rewriter.eraseBlock(entryBlock);
1137 
1138  // Branch from merge block to end block.
      // The new branch forwards the operands of the merge block's current
      // terminator to `endBlock`.
1139  Block *mergeBlock = loopOp.getMergeBlock();
1140  Operation *terminator = mergeBlock->getTerminator();
1141  ValueRange terminatorOperands = terminator->getOperands();
1142  rewriter.setInsertionPointToEnd(mergeBlock);
1143  rewriter.create<LLVM::BrOp>(loc, terminatorOperands, endBlock);
1144 
      // Move the loop's blocks in front of `endBlock` and replace the loop's
      // results with `endBlock`'s block arguments.
1145  rewriter.inlineRegionBefore(loopOp.getBody(), endBlock);
1146  rewriter.replaceOp(loopOp, endBlock->getArguments());
1147  return success();
1148  }
1149 };
1150 
1151 /// Converts `spirv.mlir.selection` with `spirv.BranchConditional` in its header
1152 /// block. All blocks within selection should be reachable for conversion to
1153 /// succeed.
///
/// The enclosing block is split after the selection op, the header's
/// conditional branch is re-created as `llvm.cond_br` at the end of the
/// enclosing block, the merge block is made to branch to the split-off
/// continuation block, and the selection's region is inlined in between.
1154 class SelectionPattern : public SPIRVToLLVMConversion<spirv::SelectionOp> {
1155 public:
1157 
1158  LogicalResult
1159  matchAndRewrite(spirv::SelectionOp op, OpAdaptor adaptor,
1160  ConversionPatternRewriter &rewriter) const override {
1161  // There is no support for `Flatten` or `DontFlatten` selection control at
1162  // the moment. These are just compiler hints and can be performed during the
1163  // optimization passes.
1164  if (op.getSelectionControl() != spirv::SelectionControl::None)
1165  return failure();
1166 
1167  // `spirv.mlir.selection` should have at least two blocks: one selection
1168  // header block and one merge block. If no blocks are present, or control
1169  // flow branches straight to merge block (two blocks are present), the op is
1170  // redundant and it is erased.
1171  if (op.getBody().getBlocks().size() <= 2) {
1172  rewriter.eraseOp(op);
1173  return success();
1174  }
1175 
1176  Location loc = op.getLoc();
1177 
1178  // Split the current block after `spirv.mlir.selection`. The remaining ops
1179  // will be used in `continueBlock`.
1180  auto *currentBlock = rewriter.getInsertionBlock();
1181  rewriter.setInsertionPointAfter(op);
1182  auto position = rewriter.getInsertionPoint();
1183  auto *continueBlock = rewriter.splitBlock(currentBlock, position);
1184 
1185  // Extract conditional branch information from the header block. By SPIR-V
1186  // dialect spec, it should contain `spirv.BranchConditional` or
1187  // `spirv.Switch` op. Note that `spirv.Switch op` is not supported at the
1188  // moment in the SPIR-V dialect. Remove this block when finished.
      // The header block is asserted to hold exactly one op; anything other
      // than `spirv.BranchConditional` makes the pattern fail.
1189  auto *headerBlock = op.getHeaderBlock();
1190  assert(headerBlock->getOperations().size() == 1);
1191  auto condBrOp = dyn_cast<spirv::BranchConditionalOp>(
1192  headerBlock->getOperations().front());
1193  if (!condBrOp)
1194  return failure();
1195  rewriter.eraseBlock(headerBlock);
1196 
1197  // Branch from merge block to continue block.
      // The new branch forwards the operands of the merge block's current
      // terminator to `continueBlock`.
1198  auto *mergeBlock = op.getMergeBlock();
1199  Operation *terminator = mergeBlock->getTerminator();
1200  ValueRange terminatorOperands = terminator->getOperands();
1201  rewriter.setInsertionPointToEnd(mergeBlock);
1202  rewriter.create<LLVM::BrOp>(loc, terminatorOperands, continueBlock);
1203 
1204  // Link current block to `true` and `false` blocks within the selection.
1205  Block *trueBlock = condBrOp.getTrueBlock();
1206  Block *falseBlock = condBrOp.getFalseBlock();
1207  rewriter.setInsertionPointToEnd(currentBlock);
1208  rewriter.create<LLVM::CondBrOp>(loc, condBrOp.getCondition(), trueBlock,
1209  condBrOp.getTrueTargetOperands(),
1210  falseBlock,
1211  condBrOp.getFalseTargetOperands());
1212 
      // Move the selection's blocks in front of `continueBlock` and replace
      // the selection's results with `continueBlock`'s block arguments.
1213  rewriter.inlineRegionBefore(op.getBody(), continueBlock);
1214  rewriter.replaceOp(op, continueBlock->getArguments());
1215  return success();
1216  }
1217 };
1218 
1219 /// Converts SPIR-V shift ops to LLVM shift ops. Since LLVM dialect
1220 /// puts a restriction on `Shift` and `Base` to have the same bit width,
1221 /// `Shift` is zero or sign extended to match this specification. Cases when
1222 /// `Shift` bit width > `Base` bit width are considered to be illegal.
///
/// `SPIRVOp` is the SPIR-V shift op and `LLVMOp` the LLVM op that replaces it.
1223 template <typename SPIRVOp, typename LLVMOp>
1224 class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> {
1225 public:
1227 
1228  LogicalResult
1229  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
1230  ConversionPatternRewriter &rewriter) const override {
1231 
1232  auto dstType = this->typeConverter.convertType(op.getType());
1233  if (!dstType)
1234  return rewriter.notifyMatchFailure(op, "type conversion failed");
1235 
1236  Type op1Type = op.getOperand1().getType();
1237  Type op2Type = op.getOperand2().getType();
1238 
      // Identical operand types: no extension is needed.
1239  if (op1Type == op2Type) {
1240  rewriter.template replaceOpWithNewOp<LLVMOp>(op, dstType,
1241  adaptor.getOperands());
1242  return success();
1243  }
1244 
      // NOTE(review): the initializer expressions on listing lines 1246 and
      // 1248 are not shown in this listing; presumably they compute the bit
      // widths of `dstType` and `op2Type` - confirm against the full file.
1245  std::optional<uint64_t> dstTypeWidth =
1247  std::optional<uint64_t> op2TypeWidth =
1249 
1250  if (!dstTypeWidth || !op2TypeWidth)
1251  return failure();
1252 
1253  Location loc = op.getLoc();
1254  Value extended;
      // Narrower shift amount: zero-extend when the SPIR-V type is unsigned,
      // sign-extend otherwise. Equal widths pass through unchanged; a wider
      // shift amount makes the pattern fail.
1255  if (op2TypeWidth < dstTypeWidth) {
1256  if (isUnsignedIntegerOrVector(op2Type)) {
1257  extended = rewriter.template create<LLVM::ZExtOp>(
1258  loc, dstType, adaptor.getOperand2());
1259  } else {
1260  extended = rewriter.template create<LLVM::SExtOp>(
1261  loc, dstType, adaptor.getOperand2());
1262  }
1263  } else if (op2TypeWidth == dstTypeWidth) {
1264  extended = adaptor.getOperand2();
1265  } else {
1266  return failure();
1267  }
1268 
1269  Value result = rewriter.template create<LLVMOp>(
1270  loc, dstType, adaptor.getOperand1(), extended);
1271  rewriter.replaceOp(op, result);
1272  return success();
1273  }
1274 };
1275 
/// Converts `spirv.GL.Tan` to `llvm.sin(x) / llvm.cos(x)`, with all three LLVM
/// ops typed as the converted result type.
1276 class TanPattern : public SPIRVToLLVMConversion<spirv::GLTanOp> {
1277 public:
1279 
1280  LogicalResult
1281  matchAndRewrite(spirv::GLTanOp tanOp, OpAdaptor adaptor,
1282  ConversionPatternRewriter &rewriter) const override {
1283  auto dstType = typeConverter.convertType(tanOp.getType());
1284  if (!dstType)
1285  return rewriter.notifyMatchFailure(tanOp, "type conversion failed");
1286 
1287  Location loc = tanOp.getLoc();
      // Feed the adaptor's (already converted) operand to both intrinsics
      // instead of the original SPIR-V operand.
1288  Value sin = rewriter.create<LLVM::SinOp>(loc, dstType, adaptor.getOperand());
1289  Value cos = rewriter.create<LLVM::CosOp>(loc, dstType, adaptor.getOperand());
1290  rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanOp, dstType, sin, cos);
1291  return success();
1292  }
1293 };
1294 
1295 /// Convert `spirv.Tanh` to
1296 ///
1297 ///   exp(2x) - 1
1298 ///   -----------
1299 ///   exp(2x) + 1
1300 ///
/// The constants 2.0 and 1.0 are built with `createFPConstant`; every
/// intermediate LLVM op is typed as the converted result type.
1301 class TanhPattern : public SPIRVToLLVMConversion<spirv::GLTanhOp> {
1302 public:
1304 
1305  LogicalResult
1306  matchAndRewrite(spirv::GLTanhOp tanhOp, OpAdaptor adaptor,
1307  ConversionPatternRewriter &rewriter) const override {
1308  auto srcType = tanhOp.getType();
1309  auto dstType = typeConverter.convertType(srcType);
1310  if (!dstType)
1311  return rewriter.notifyMatchFailure(tanhOp, "type conversion failed");
1312 
1313  Location loc = tanhOp.getLoc();
1314  Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0);
      // 2x, computed from the adaptor's (already converted) operand rather
      // than from the original SPIR-V operand.
1315  Value multiplied =
1316  rewriter.create<LLVM::FMulOp>(loc, dstType, two, adaptor.getOperand());
1317  Value exponential = rewriter.create<LLVM::ExpOp>(loc, dstType, multiplied);
1318  Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
1319  Value numerator =
1320  rewriter.create<LLVM::FSubOp>(loc, dstType, exponential, one);
1321  Value denominator =
1322  rewriter.create<LLVM::FAddOp>(loc, dstType, exponential, one);
1323  rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanhOp, dstType, numerator,
1324  denominator);
1325  return success();
1326  }
1327 };
1328 
/// Converts `spirv.Variable` to an `llvm.alloca` of one element of the
/// converted pointee type. When the variable has an initializer, the converted
/// initializer is additionally stored into the allocation. Initializers are
/// accepted only for int/float scalars and vectors; other initialized
/// variables make the pattern fail.
1329 class VariablePattern : public SPIRVToLLVMConversion<spirv::VariableOp> {
1330 public:
1332 
1333  LogicalResult
1334  matchAndRewrite(spirv::VariableOp varOp, OpAdaptor adaptor,
1335  ConversionPatternRewriter &rewriter) const override {
1336  auto srcType = varOp.getType();
1337  // Initialization is supported for scalars and vectors only.
1338  auto pointerTo = cast<spirv::PointerType>(srcType).getPointeeType();
1339  auto init = varOp.getInitializer();
1340  if (init && !pointerTo.isIntOrFloat() && !isa<VectorType>(pointerTo))
1341  return failure();
1342 
      // `dstType` is the converted pointer type, used as the alloca's result.
1343  auto dstType = typeConverter.convertType(srcType);
1344  if (!dstType)
1345  return rewriter.notifyMatchFailure(varOp, "type conversion failed");
1346 
1347  Location loc = varOp.getLoc();
      // Array size operand of the alloca: a single element.
1348  Value size = createI32ConstantOf(loc, rewriter, 1);
      // Uninitialized variable: the alloca alone replaces the op.
1349  if (!init) {
1350  auto elementType = typeConverter.convertType(pointerTo);
1351  if (!elementType)
1352  return rewriter.notifyMatchFailure(varOp, "type conversion failed");
1353  rewriter.replaceOpWithNewOp<LLVM::AllocaOp>(varOp, dstType, elementType,
1354  size);
1355  return success();
1356  }
      // Initialized variable: allocate, store the converted initializer, and
      // replace the op with the allocated pointer.
1357  auto elementType = typeConverter.convertType(pointerTo);
1358  if (!elementType)
1359  return rewriter.notifyMatchFailure(varOp, "type conversion failed");
1360  Value allocated =
1361  rewriter.create<LLVM::AllocaOp>(loc, dstType, elementType, size);
1362  rewriter.create<LLVM::StoreOp>(loc, adaptor.getInitializer(), allocated);
1363  rewriter.replaceOp(varOp, allocated);
1364  return success();
1365  }
1366 };
1367 
1368 //===----------------------------------------------------------------------===//
1369 // BitcastOp conversion
1370 //===----------------------------------------------------------------------===//
1371 
/// Converts `spirv.Bitcast` to `llvm.bitcast`. When the converted result type
/// is an LLVM pointer type, no cast is emitted and the op is replaced directly
/// by its converted operand.
1372 class BitcastConversionPattern
1373  : public SPIRVToLLVMConversion<spirv::BitcastOp> {
1374 public:
1376 
1377  LogicalResult
1378  matchAndRewrite(spirv::BitcastOp bitcastOp, OpAdaptor adaptor,
1379  ConversionPatternRewriter &rewriter) const override {
1380  auto dstType = typeConverter.convertType(bitcastOp.getType());
1381  if (!dstType)
1382  return rewriter.notifyMatchFailure(bitcastOp, "type conversion failed");
1383 
1384  // LLVM's opaque pointers do not require bitcasts.
1385  if (isa<LLVM::LLVMPointerType>(dstType)) {
1386  rewriter.replaceOp(bitcastOp, adaptor.getOperand());
1387  return success();
1388  }
1389 
      // Non-pointer result: emit the cast, carrying over the op's attributes.
1390  rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
1391  bitcastOp, dstType, adaptor.getOperands(), bitcastOp->getAttrs());
1392  return success();
1393  }
1394 };
1395 
1396 //===----------------------------------------------------------------------===//
1397 // FuncOp conversion
1398 //===----------------------------------------------------------------------===//
1399 
/// Converts `spirv.func` to `llvm.func`. The function signature is converted
/// through the type converter, the SPIR-V function control is mapped to LLVM
/// function attributes, and the body region is moved into the new function
/// with its block argument types converted. The pattern fails when either the
/// signature conversion or the region type conversion fails.
1400 class FuncConversionPattern : public SPIRVToLLVMConversion<spirv::FuncOp> {
1401 public:
1403 
1404  LogicalResult
1405  matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor,
1406  ConversionPatternRewriter &rewriter) const override {
1407 
1408  // Convert function signature. At the moment LLVMType converter is enough
1409  // for currently supported types.
1410  auto funcType = funcOp.getFunctionType();
1411  TypeConverter::SignatureConversion signatureConverter(
1412  funcType.getNumInputs());
1413  auto llvmType = typeConverter.convertFunctionSignature(
1414  funcType, /*isVariadic=*/false, /*useBarePtrCallConv=*/false,
1415  signatureConverter);
1416  if (!llvmType)
1417  return failure();
1418 
1419  // Create a new `LLVMFuncOp`
1420  Location loc = funcOp.getLoc();
1421  StringRef name = funcOp.getName();
1422  auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(loc, name, llvmType);
1423 
1424  // Convert SPIR-V Function Control to equivalent LLVM function attribute
      // `Inline` and `DontInline` set the always-inline / no-inline flags;
      // `Pure` and `Const` become a one-element "passthrough" array holding
      // "readonly" and "readnone" respectively.
1425  MLIRContext *context = funcOp.getContext();
1426  switch (funcOp.getFunctionControl()) {
1427  case spirv::FunctionControl::Inline:
1428  newFuncOp.setAlwaysInline(true);
1429  break;
1430  case spirv::FunctionControl::DontInline:
1431  newFuncOp.setNoInline(true);
1432  break;
1433 
1434 #define DISPATCH(functionControl, llvmAttr) \
1435  case functionControl: \
1436  newFuncOp->setAttr("passthrough", ArrayAttr::get(context, {llvmAttr})); \
1437  break;
1438 
1439  DISPATCH(spirv::FunctionControl::Pure,
1440  StringAttr::get(context, "readonly"));
1441  DISPATCH(spirv::FunctionControl::Const,
1442  StringAttr::get(context, "readnone"));
1443 
1444 #undef DISPATCH
1445 
1446  // Default: if `spirv::FunctionControl::None`, then no attributes are
1447  // needed.
1448  default:
1449  break;
1450  }
1451 
      // Move the body into the new function, then convert the region's types
      // using the signature conversion computed above.
1452  rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
1453  newFuncOp.end());
1454  if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter,
1455  &signatureConverter))) {
1456  return failure();
1457  }
1458  rewriter.eraseOp(funcOp);
1459  return success();
1460  }
1461 };
1462 
1463 //===----------------------------------------------------------------------===//
1464 // ModuleOp conversion
1465 //===----------------------------------------------------------------------===//
1466 
/// Converts `spirv.module` to a builtin `ModuleOp` that has the same location
/// and name. The SPIR-V module's region is moved into the new module and the
/// SPIR-V module op is erased. Always succeeds.
1467 class ModuleConversionPattern : public SPIRVToLLVMConversion<spirv::ModuleOp> {
1468 public:
1470 
1471  LogicalResult
1472  matchAndRewrite(spirv::ModuleOp spvModuleOp, OpAdaptor adaptor,
1473  ConversionPatternRewriter &rewriter) const override {
1474 
1475  auto newModuleOp =
1476  rewriter.create<ModuleOp>(spvModuleOp.getLoc(), spvModuleOp.getName());
      // The SPIR-V region is inserted before the new module's own body block,
      // which therefore ends up last in the body region.
1477  rewriter.inlineRegionBefore(spvModuleOp.getRegion(), newModuleOp.getBody());
1478 
1479  // Remove the terminator block that was automatically added by builder
1480  rewriter.eraseBlock(&newModuleOp.getBodyRegion().back());
1481  rewriter.eraseOp(spvModuleOp);
1482  return success();
1483  }
1484 };
1485 
1486 //===----------------------------------------------------------------------===//
1487 // VectorShuffleOp conversion
1488 //===----------------------------------------------------------------------===//
1489 
/// Converts `spirv.VectorShuffle` to LLVM dialect.
///
/// When both input vectors have the same number of elements the op maps to a
/// single `llvm.shufflevector`. Otherwise the result is assembled one element
/// at a time, starting from `llvm.mlir.undef` and using one
/// `llvm.extractelement` / `llvm.insertelement` pair per component.
1490 class VectorShufflePattern
1491  : public SPIRVToLLVMConversion<spirv::VectorShuffleOp> {
1492 public:
1494  LogicalResult
1495  matchAndRewrite(spirv::VectorShuffleOp op, OpAdaptor adaptor,
1496  ConversionPatternRewriter &rewriter) const override {
1497  Location loc = op.getLoc();
1498  auto components = adaptor.getComponents();
1499  auto vector1 = adaptor.getVector1();
1500  auto vector2 = adaptor.getVector2();
1501  int vector1Size = cast<VectorType>(vector1.getType()).getNumElements();
1502  int vector2Size = cast<VectorType>(vector2.getType()).getNumElements();
      // Same-sized inputs: map directly onto `llvm.shufflevector`.
1503  if (vector1Size == vector2Size) {
1504  rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(
1505  op, vector1, vector2,
1506  LLVM::convertArrayToIndices<int32_t>(components));
1507  return success();
1508  }
1509 
1510  auto dstType = typeConverter.convertType(op.getType());
1511  if (!dstType)
1512  return rewriter.notifyMatchFailure(op, "type conversion failed");
1513  auto scalarType = cast<VectorType>(dstType).getElementType();
1514  auto componentsArray = components.getValue();
1515  auto *context = rewriter.getContext();
1516  auto llvmI32Type = IntegerType::get(context, 32);
      // Differently sized inputs: start from an undef vector and fill it in.
1517  Value targetOp = rewriter.create<LLVM::UndefOp>(loc, dstType);
1518  for (unsigned i = 0; i < componentsArray.size(); i++) {
1519  if (!isa<IntegerAttr>(componentsArray[i]))
1520  return op.emitError("unable to support non-constant component");
1521 
      // A component of -1 is skipped, leaving that result element as undef.
1522  int indexVal = cast<IntegerAttr>(componentsArray[i]).getInt();
1523  if (indexVal == -1)
1524  continue;
1525 
      // Components at or beyond `vector1Size` select from `vector2`, with the
      // index rebased by `vector1Size`.
1526  int offsetVal = 0;
1527  Value baseVector = vector1;
1528  if (indexVal >= vector1Size) {
1529  offsetVal = vector1Size;
1530  baseVector = vector2;
1531  }
1532 
1533  Value dstIndex = rewriter.create<LLVM::ConstantOp>(
1534  loc, llvmI32Type, rewriter.getIntegerAttr(rewriter.getI32Type(), i));
1535  Value index = rewriter.create<LLVM::ConstantOp>(
1536  loc, llvmI32Type,
1537  rewriter.getIntegerAttr(rewriter.getI32Type(), indexVal - offsetVal));
1538 
1539  auto extractOp = rewriter.create<LLVM::ExtractElementOp>(
1540  loc, scalarType, baseVector, index);
1541  targetOp = rewriter.create<LLVM::InsertElementOp>(loc, dstType, targetOp,
1542  extractOp, dstIndex);
1543  }
1544  rewriter.replaceOp(op, targetOp);
1545  return success();
1546  }
1547 };
1548 } // namespace
1549 
1550 //===----------------------------------------------------------------------===//
1551 // Pattern population
1552 //===----------------------------------------------------------------------===//
1553 
// Registers conversions on `typeConverter` for the SPIR-V array, pointer,
// runtime array and struct types. Each callback delegates to the matching
// `convert*Type` helper; only the pointer conversion also uses `clientAPI`.
// NOTE(review): the first line of this function's signature (listing line
// 1554) is not shown in this listing.
1555  spirv::ClientAPI clientAPI) {
1556  typeConverter.addConversion([&](spirv::ArrayType type) {
1557  return convertArrayType(type, typeConverter);
1558  });
      // `clientAPI` is captured by value; the type converter by reference.
1559  typeConverter.addConversion([&, clientAPI](spirv::PointerType type) {
1560  return convertPointerType(type, typeConverter, clientAPI);
1561  });
1562  typeConverter.addConversion([&](spirv::RuntimeArrayType type) {
1563  return convertRuntimeArrayType(type, typeConverter);
1564  });
1565  typeConverter.addConversion([&](spirv::StructType type) {
1566  return convertStructType(type, typeConverter);
1567  });
1568 }
1569 
1571  LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
1572  spirv::ClientAPI clientAPI) {
1573  patterns.add<
1574  // Arithmetic ops
1575  DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>,
1576  DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>,
1577  DirectConversionPattern<spirv::ISubOp, LLVM::SubOp>,
1578  DirectConversionPattern<spirv::FAddOp, LLVM::FAddOp>,
1579  DirectConversionPattern<spirv::FDivOp, LLVM::FDivOp>,
1580  DirectConversionPattern<spirv::FMulOp, LLVM::FMulOp>,
1581  DirectConversionPattern<spirv::FNegateOp, LLVM::FNegOp>,
1582  DirectConversionPattern<spirv::FRemOp, LLVM::FRemOp>,
1583  DirectConversionPattern<spirv::FSubOp, LLVM::FSubOp>,
1584  DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>,
1585  DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>,
1586  DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>,
1587  DirectConversionPattern<spirv::UModOp, LLVM::URemOp>,
1588 
1589  // Bitwise ops
1590  BitFieldInsertPattern, BitFieldUExtractPattern, BitFieldSExtractPattern,
1591  DirectConversionPattern<spirv::BitCountOp, LLVM::CtPopOp>,
1592  DirectConversionPattern<spirv::BitReverseOp, LLVM::BitReverseOp>,
1593  DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>,
1594  DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>,
1595  DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>,
1596  NotPattern<spirv::NotOp>,
1597 
1598  // Cast ops
1599  BitcastConversionPattern,
1600  DirectConversionPattern<spirv::ConvertFToSOp, LLVM::FPToSIOp>,
1601  DirectConversionPattern<spirv::ConvertFToUOp, LLVM::FPToUIOp>,
1602  DirectConversionPattern<spirv::ConvertSToFOp, LLVM::SIToFPOp>,
1603  DirectConversionPattern<spirv::ConvertUToFOp, LLVM::UIToFPOp>,
1604  IndirectCastPattern<spirv::FConvertOp, LLVM::FPExtOp, LLVM::FPTruncOp>,
1605  IndirectCastPattern<spirv::SConvertOp, LLVM::SExtOp, LLVM::TruncOp>,
1606  IndirectCastPattern<spirv::UConvertOp, LLVM::ZExtOp, LLVM::TruncOp>,
1607 
1608  // Comparison ops
1609  IComparePattern<spirv::IEqualOp, LLVM::ICmpPredicate::eq>,
1610  IComparePattern<spirv::INotEqualOp, LLVM::ICmpPredicate::ne>,
1611  FComparePattern<spirv::FOrdEqualOp, LLVM::FCmpPredicate::oeq>,
1612  FComparePattern<spirv::FOrdGreaterThanOp, LLVM::FCmpPredicate::ogt>,
1613  FComparePattern<spirv::FOrdGreaterThanEqualOp, LLVM::FCmpPredicate::oge>,
1614  FComparePattern<spirv::FOrdLessThanEqualOp, LLVM::FCmpPredicate::ole>,
1615  FComparePattern<spirv::FOrdLessThanOp, LLVM::FCmpPredicate::olt>,
1616  FComparePattern<spirv::FOrdNotEqualOp, LLVM::FCmpPredicate::one>,
1617  FComparePattern<spirv::FUnordEqualOp, LLVM::FCmpPredicate::ueq>,
1618  FComparePattern<spirv::FUnordGreaterThanOp, LLVM::FCmpPredicate::ugt>,
1619  FComparePattern<spirv::FUnordGreaterThanEqualOp,
1620  LLVM::FCmpPredicate::uge>,
1621  FComparePattern<spirv::FUnordLessThanEqualOp, LLVM::FCmpPredicate::ule>,
1622  FComparePattern<spirv::FUnordLessThanOp, LLVM::FCmpPredicate::ult>,
1623  FComparePattern<spirv::FUnordNotEqualOp, LLVM::FCmpPredicate::une>,
1624  IComparePattern<spirv::SGreaterThanOp, LLVM::ICmpPredicate::sgt>,
1625  IComparePattern<spirv::SGreaterThanEqualOp, LLVM::ICmpPredicate::sge>,
1626  IComparePattern<spirv::SLessThanEqualOp, LLVM::ICmpPredicate::sle>,
1627  IComparePattern<spirv::SLessThanOp, LLVM::ICmpPredicate::slt>,
1628  IComparePattern<spirv::UGreaterThanOp, LLVM::ICmpPredicate::ugt>,
1629  IComparePattern<spirv::UGreaterThanEqualOp, LLVM::ICmpPredicate::uge>,
1630  IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>,
1631  IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>,
1632 
1633  // Constant op
1634  ConstantScalarAndVectorPattern,
1635 
1636  // Control Flow ops
1637  BranchConversionPattern, BranchConditionalConversionPattern,
1638  FunctionCallPattern, LoopPattern, SelectionPattern,
1639  ErasePattern<spirv::MergeOp>,
1640 
1641  // Entry points and execution mode are handled separately.
1642  ErasePattern<spirv::EntryPointOp>, ExecutionModePattern,
1643 
1644  // GLSL extended instruction set ops
1645  DirectConversionPattern<spirv::GLCeilOp, LLVM::FCeilOp>,
1646  DirectConversionPattern<spirv::GLCosOp, LLVM::CosOp>,
1647  DirectConversionPattern<spirv::GLExpOp, LLVM::ExpOp>,
1648  DirectConversionPattern<spirv::GLFAbsOp, LLVM::FAbsOp>,
1649  DirectConversionPattern<spirv::GLFloorOp, LLVM::FFloorOp>,
1650  DirectConversionPattern<spirv::GLFMaxOp, LLVM::MaxNumOp>,
1651  DirectConversionPattern<spirv::GLFMinOp, LLVM::MinNumOp>,
1652  DirectConversionPattern<spirv::GLLogOp, LLVM::LogOp>,
1653  DirectConversionPattern<spirv::GLSinOp, LLVM::SinOp>,
1654  DirectConversionPattern<spirv::GLSMaxOp, LLVM::SMaxOp>,
1655  DirectConversionPattern<spirv::GLSMinOp, LLVM::SMinOp>,
1656  DirectConversionPattern<spirv::GLSqrtOp, LLVM::SqrtOp>,
1657  InverseSqrtPattern, TanPattern, TanhPattern,
1658 
1659  // Logical ops
1660  DirectConversionPattern<spirv::LogicalAndOp, LLVM::AndOp>,
1661  DirectConversionPattern<spirv::LogicalOrOp, LLVM::OrOp>,
1662  IComparePattern<spirv::LogicalEqualOp, LLVM::ICmpPredicate::eq>,
1663  IComparePattern<spirv::LogicalNotEqualOp, LLVM::ICmpPredicate::ne>,
1664  NotPattern<spirv::LogicalNotOp>,
1665 
1666  // Memory ops
1667  AccessChainPattern, AddressOfPattern, LoadStorePattern<spirv::LoadOp>,
1668  LoadStorePattern<spirv::StoreOp>, VariablePattern,
1669 
1670  // Miscellaneous ops
1671  CompositeExtractPattern, CompositeInsertPattern,
1672  DirectConversionPattern<spirv::SelectOp, LLVM::SelectOp>,
1673  DirectConversionPattern<spirv::UndefOp, LLVM::UndefOp>,
1674  VectorShufflePattern,
1675 
1676  // Shift ops
1677  ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>,
1678  ShiftPattern<spirv::ShiftRightLogicalOp, LLVM::LShrOp>,
1679  ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>,
1680 
1681  // Return ops
1682  ReturnPattern, ReturnValuePattern>(patterns.getContext(), typeConverter);
1683 
1684  patterns.add<GlobalVariablePattern>(clientAPI, patterns.getContext(),
1685  typeConverter);
1686 }
1687 
1688 void mlir::populateSPIRVToLLVMFunctionConversionPatterns(
1689  LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
1690  patterns.add<FuncConversionPattern>(patterns.getContext(), typeConverter);
1691 }
1692 
1693 void mlir::populateSPIRVToLLVMModuleConversionPatterns(
1694  LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
1695  patterns.add<ModuleConversionPattern>(patterns.getContext(), typeConverter);
1696 }
1697 
1698 //===----------------------------------------------------------------------===//
1699 // Pre-conversion hooks
1700 //===----------------------------------------------------------------------===//
1701 
1702 /// Hook for descriptor set and binding number encoding.
1703 static constexpr StringRef kBinding = "binding";
1704 static constexpr StringRef kDescriptorSet = "descriptor_set";
1705 void mlir::encodeBindAttribute(ModuleOp module) {
1706  auto spvModules = module.getOps<spirv::ModuleOp>();
1707  for (auto spvModule : spvModules) {
1708  spvModule.walk([&](spirv::GlobalVariableOp op) {
1709  IntegerAttr descriptorSet =
1710  op->getAttrOfType<IntegerAttr>(kDescriptorSet);
1711  IntegerAttr binding = op->getAttrOfType<IntegerAttr>(kBinding);
1712  // Only global variables that carry both a descriptor set and a binding
1713  // number are rewritten; all other variables are left untouched.
1714  if (descriptorSet && binding) {
1715  // Encode these numbers into the variable's symbolic name. If the
1716  // SPIR-V module has a name, add it at the beginning.
1717  auto moduleAndName =
1718  spvModule.getName().has_value()
1719  ? spvModule.getName()->str() + "_" + op.getSymName().str()
1720  : op.getSymName().str();
1721  std::string name =
1722  llvm::formatv("{0}_descriptor_set{1}_binding{2}", moduleAndName,
1723  std::to_string(descriptorSet.getInt()),
1724  std::to_string(binding.getInt()));
1725  auto nameAttr = StringAttr::get(op->getContext(), name);
1726 
1727  // Replace all symbol uses and set the new symbol name. Finally, remove
1728  // descriptor set and binding attributes.
1729  if (failed(SymbolTable::replaceAllSymbolUses(op, nameAttr, spvModule)))
1730  op.emitError("unable to replace all symbol uses for ") << name;
1731  SymbolTable::setSymbolName(op, nameAttr);
1732  op->removeAttr(kDescriptorSet);
1733  op->removeAttr(kBinding);
1734  }
1735  });
1736  }
1737 }
static MLIRContext * getContext(OpFoldResult val)
@ None
static Value optionallyTruncateOrExtend(Location loc, Value value, Type llvmType, PatternRewriter &rewriter)
Utility function for bitfield ops: BitFieldInsert, BitFieldSExtract and BitFieldUExtract.
static Value createFPConstant(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter, double value)
Creates llvm.mlir.constant with a floating-point scalar or vector value.
static constexpr StringRef kDescriptorSet
static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter, unsigned value)
Creates LLVM dialect constant with the given value.
static Type convertStructTypeWithOffset(spirv::StructType type, LLVMTypeConverter &converter)
Converts SPIR-V struct with a regular (according to VulkanLayoutUtils) offset to LLVM struct.
static Value processCountOrOffset(Location loc, Value value, Type srcType, Type dstType, LLVMTypeConverter &converter, ConversionPatternRewriter &rewriter)
Utility function for bitfield ops: BitFieldInsert, BitFieldSExtract and BitFieldUExtract.
static unsigned mapToOpenCLAddressSpace(spirv::StorageClass storageClass)
static unsigned getBitWidth(Type type)
Returns the bit width of integer, float or vector of float or integer values.
Definition: SPIRVToLLVM.cpp:71
constexpr unsigned defaultAddressSpace
Definition: SPIRVToLLVM.cpp:35
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
static std::optional< Type > convertRuntimeArrayType(spirv::RuntimeArrayType type, TypeConverter &converter)
Converts SPIR-V runtime array to LLVM array.
#define STORAGE_SPACE_MAP(storage, space)
static bool isSignedIntegerOrVector(Type type)
Returns true if the given type is a signed integer or vector type.
Definition: SPIRVToLLVM.cpp:42
static bool isUnsignedIntegerOrVector(Type type)
Returns true if the given type is an unsigned integer or vector type.
Definition: SPIRVToLLVM.cpp:51
static Type convertStructTypePacked(spirv::StructType type, LLVMTypeConverter &converter)
Converts SPIR-V struct with no offset to packed LLVM struct.
static std::optional< Type > convertArrayType(spirv::ArrayType type, TypeConverter &converter)
Converts SPIR-V array type to LLVM array.
static constexpr StringRef kBinding
Hook for descriptor set and binding number encoding.
static unsigned mapToAddressSpace(spirv::ClientAPI clientAPI, spirv::StorageClass storageClass)
static Type convertStructType(spirv::StructType type, LLVMTypeConverter &converter)
Converts SPIR-V struct to LLVM struct.
static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder)
Creates IntegerAttribute with all bits set for given type.
Definition: SPIRVToLLVM.cpp:92
static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter, LLVMTypeConverter &typeConverter, unsigned alignment, bool isVolatile, bool isNonTemporal)
Utility for spirv.Load and spirv.Store conversion.
static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter)
Creates llvm.mlir.constant with all bits set for the given type.
static unsigned getLLVMTypeBitWidth(Type type)
Returns the bit width of LLVMType integer or vector.
Definition: SPIRVToLLVM.cpp:84
static Value optionallyBroadcast(Location loc, Value value, Type srcType, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value. If srcType is a scalar, the value remains unchanged.
#define DISPATCH(functionControl, llvmAttr)
static std::optional< uint64_t > getIntegerOrVectorElementWidth(Type type)
Returns the width of an integer or of the element type of an integer vector, if applicable.
Definition: SPIRVToLLVM.cpp:61
#define CLIENT_MAP(client, storage)
static Type convertPointerType(spirv::PointerType type, LLVMTypeConverter &converter, spirv::ClientAPI clientAPI)
Converts SPIR-V pointer type to LLVM pointer.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:31
OpListType::iterator iterator
Definition: Block.h:138
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:243
OpListType & getOperations()
Definition: Block.h:135
BlockArgListType getArguments()
Definition: Block.h:85
iterator begin()
Definition: Block.h:141
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:220
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:242
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:265
IntegerType getI32Type()
Definition: Builders.cpp:87
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:91
MLIRContext * getContext() const
Definition: Builders.h:55
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void eraseBlock(Block *block) override
PatternRewriter hook for erase all operations in a block.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:34
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
static LLVMStructType getLiteral(MLIRContext *context, ArrayRef< Type > types, bool isPacked=false)
Gets or creates a literal struct with the given body in the provided context.
Definition: LLVMTypes.cpp:453
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:351
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:448
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:434
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:401
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:439
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:441
Block * getBlock() const
Returns the current block of the builder.
Definition: Builders.h:451
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:468
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:415
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:445
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
AttrClass getAttrOfType(StringAttr name)
Definition: Operation.h:545
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:507
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
Attribute removeAttr(StringAttr name)
Remove the attribute with the specified name if it exists.
Definition: Operation.h:595
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
MLIRContext * getContext() const
Definition: PatternMatch.h:823
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, StringAttr newSymbol, Operation *from)
Attempt to replace all uses of the given symbol 'oldSymbol' with the provided symbol 'newSymbol' that...
static void setSymbolName(Operation *symbol, StringAttr name)
Sets the name of the given symbol operation.
This class provides all of the information necessary to convert a type signature.
Type conversion class.
void addConversion(FnT &&callback)
Register a conversion function.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
Definition: Types.cpp:80
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition: Types.cpp:92
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:120
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:126
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
static spirv::StructType decorateType(spirv::StructType structType)
Returns a new StructType with layout decoration.
Definition: LayoutUtils.cpp:21
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Builder from ArrayRef<T>.
Type getElementType() const
Definition: SPIRVTypes.cpp:66
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:68
unsigned getNumElements() const
Definition: SPIRVTypes.cpp:64
StorageClass getStorageClass() const
Definition: SPIRVTypes.cpp:487
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:548
SPIR-V struct type.
Definition: SPIRVTypes.h:293
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
TypeRange getElementTypes() const
bool isCompatibleVectorType(Type type)
Returns true if the given type is a vector type compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:874
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:856
SmallVector< IntT > convertArrayToIndices(ArrayRef< Attribute > attrs)
Convert an array of integer attributes to a vector of integers that can be used as indices in LLVM op...
Definition: LLVMDialect.h:220
Type getVectorElementType(Type type)
Returns the element type of any vector type compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:890
Include the generated interface declarations.
void populateSPIRVToLLVMFunctionConversionPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the given list with patterns for function conversion from SPIR-V to LLVM.
void populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter, spirv::ClientAPI clientAPIForAddressSpaceMapping=spirv::ClientAPI::Unknown)
Populates type conversions with additional SPIR-V types.
void populateSPIRVToLLVMConversionPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, spirv::ClientAPI clientAPIForAddressSpaceMapping=spirv::ClientAPI::Unknown)
Populates the given list with patterns that convert from SPIR-V to LLVM.
void populateSPIRVToLLVMModuleConversionPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the given patterns for module conversion from SPIR-V to LLVM.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void encodeBindAttribute(ModuleOp module)
Encodes global variable's descriptor set and binding into its name if they both exist.