MLIR  20.0.0git
SPIRVToLLVM.cpp
Go to the documentation of this file.
1 //===- SPIRVToLLVM.cpp - SPIR-V to LLVM Patterns --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert SPIR-V dialect to LLVM dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
22 #include "mlir/IR/BuiltinOps.h"
23 #include "mlir/IR/PatternMatch.h"
25 #include "llvm/Support/Debug.h"
26 #include "llvm/Support/FormatVariadic.h"
27 
28 #define DEBUG_TYPE "spirv-to-llvm-pattern"
29 
30 using namespace mlir;
31 
32 //===----------------------------------------------------------------------===//
33 // Utility functions
34 //===----------------------------------------------------------------------===//
35 
36 /// Returns true if the given type is a signed integer or vector type.
37 static bool isSignedIntegerOrVector(Type type) {
38  if (type.isSignedInteger())
39  return true;
40  if (auto vecType = dyn_cast<VectorType>(type))
41  return vecType.getElementType().isSignedInteger();
42  return false;
43 }
44 
45 /// Returns true if the given type is an unsigned integer or vector type
46 static bool isUnsignedIntegerOrVector(Type type) {
47  if (type.isUnsignedInteger())
48  return true;
49  if (auto vecType = dyn_cast<VectorType>(type))
50  return vecType.getElementType().isUnsignedInteger();
51  return false;
52 }
53 
54 /// Returns the width of an integer or of the element type of an integer vector,
55 /// if applicable.
56 static std::optional<uint64_t> getIntegerOrVectorElementWidth(Type type) {
57  if (auto intType = dyn_cast<IntegerType>(type))
58  return intType.getWidth();
59  if (auto vecType = dyn_cast<VectorType>(type))
60  if (auto intType = dyn_cast<IntegerType>(vecType.getElementType()))
61  return intType.getWidth();
62  return std::nullopt;
63 }
64 
65 /// Returns the bit width of integer, float or vector of float or integer values
66 static unsigned getBitWidth(Type type) {
67  assert((type.isIntOrFloat() || isa<VectorType>(type)) &&
68  "bitwidth is not supported for this type");
69  if (type.isIntOrFloat())
70  return type.getIntOrFloatBitWidth();
71  auto vecType = dyn_cast<VectorType>(type);
72  auto elementType = vecType.getElementType();
73  assert(elementType.isIntOrFloat() &&
74  "only integers and floats have a bitwidth");
75  return elementType.getIntOrFloatBitWidth();
76 }
77 
78 /// Returns the bit width of LLVMType integer or vector.
79 static unsigned getLLVMTypeBitWidth(Type type) {
80  return cast<IntegerType>((LLVM::isCompatibleVectorType(type)
82  : type))
83  .getWidth();
84 }
85 
86 /// Creates `IntegerAttribute` with all bits set for given type
87 static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder) {
88  if (auto vecType = dyn_cast<VectorType>(type)) {
89  auto integerType = cast<IntegerType>(vecType.getElementType());
90  return builder.getIntegerAttr(integerType, -1);
91  }
92  auto integerType = cast<IntegerType>(type);
93  return builder.getIntegerAttr(integerType, -1);
94 }
95 
96 /// Creates `llvm.mlir.constant` with all bits set for the given type.
97 static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType,
98  PatternRewriter &rewriter) {
99  if (isa<VectorType>(srcType)) {
100  return rewriter.create<LLVM::ConstantOp>(
101  loc, dstType,
102  SplatElementsAttr::get(cast<ShapedType>(srcType),
103  minusOneIntegerAttribute(srcType, rewriter)));
104  }
105  return rewriter.create<LLVM::ConstantOp>(
106  loc, dstType, minusOneIntegerAttribute(srcType, rewriter));
107 }
108 
109 /// Creates `llvm.mlir.constant` with a floating-point scalar or vector value.
110 static Value createFPConstant(Location loc, Type srcType, Type dstType,
111  PatternRewriter &rewriter, double value) {
112  if (auto vecType = dyn_cast<VectorType>(srcType)) {
113  auto floatType = cast<FloatType>(vecType.getElementType());
114  return rewriter.create<LLVM::ConstantOp>(
115  loc, dstType,
116  SplatElementsAttr::get(vecType,
117  rewriter.getFloatAttr(floatType, value)));
118  }
119  auto floatType = cast<FloatType>(srcType);
120  return rewriter.create<LLVM::ConstantOp>(
121  loc, dstType, rewriter.getFloatAttr(floatType, value));
122 }
123 
124 /// Utility function for bitfield ops:
125 /// - `BitFieldInsert`
126 /// - `BitFieldSExtract`
127 /// - `BitFieldUExtract`
128 /// Truncates or extends the value. If the bitwidth of the value is the same as
129 /// `llvmType` bitwidth, the value remains unchanged.
131  Type llvmType,
132  PatternRewriter &rewriter) {
133  auto srcType = value.getType();
134  unsigned targetBitWidth = getLLVMTypeBitWidth(llvmType);
135  unsigned valueBitWidth = LLVM::isCompatibleType(srcType)
136  ? getLLVMTypeBitWidth(srcType)
137  : getBitWidth(srcType);
138 
139  if (valueBitWidth < targetBitWidth)
140  return rewriter.create<LLVM::ZExtOp>(loc, llvmType, value);
141  // If the bit widths of `Count` and `Offset` are greater than the bit width
142  // of the target type, they are truncated. Truncation is safe since `Count`
143  // and `Offset` must be no more than 64 for op behaviour to be defined. Hence,
144  // both values can be expressed in 8 bits.
145  if (valueBitWidth > targetBitWidth)
146  return rewriter.create<LLVM::TruncOp>(loc, llvmType, value);
147  return value;
148 }
149 
150 /// Broadcasts the value to vector with `numElements` number of elements.
151 static Value broadcast(Location loc, Value toBroadcast, unsigned numElements,
152  const TypeConverter &typeConverter,
153  ConversionPatternRewriter &rewriter) {
154  auto vectorType = VectorType::get(numElements, toBroadcast.getType());
155  auto llvmVectorType = typeConverter.convertType(vectorType);
156  auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32));
157  Value broadcasted = rewriter.create<LLVM::UndefOp>(loc, llvmVectorType);
158  for (unsigned i = 0; i < numElements; ++i) {
159  auto index = rewriter.create<LLVM::ConstantOp>(
160  loc, llvmI32Type, rewriter.getI32IntegerAttr(i));
161  broadcasted = rewriter.create<LLVM::InsertElementOp>(
162  loc, llvmVectorType, broadcasted, toBroadcast, index);
163  }
164  return broadcasted;
165 }
166 
167 /// Broadcasts the value. If `srcType` is a scalar, the value remains unchanged.
168 static Value optionallyBroadcast(Location loc, Value value, Type srcType,
169  const TypeConverter &typeConverter,
170  ConversionPatternRewriter &rewriter) {
171  if (auto vectorType = dyn_cast<VectorType>(srcType)) {
172  unsigned numElements = vectorType.getNumElements();
173  return broadcast(loc, value, numElements, typeConverter, rewriter);
174  }
175  return value;
176 }
177 
178 /// Utility function for bitfield ops: `BitFieldInsert`, `BitFieldSExtract` and
179 /// `BitFieldUExtract`.
180 /// Broadcast `Offset` and `Count` to match the type of `Base`. If `Base` is of
181 /// a vector type, construct a vector that has:
182 /// - same number of elements as `Base`
183 /// - each element has the type that is the same as the type of `Offset` or
184 /// `Count`
185 /// - each element has the same value as `Offset` or `Count`
186 /// Then cast `Offset` and `Count` if their bit width is different
187 /// from `Base` bit width.
188 static Value processCountOrOffset(Location loc, Value value, Type srcType,
189  Type dstType, const TypeConverter &converter,
190  ConversionPatternRewriter &rewriter) {
191  Value broadcasted =
192  optionallyBroadcast(loc, value, srcType, converter, rewriter);
193  return optionallyTruncateOrExtend(loc, broadcasted, dstType, rewriter);
194 }
195 
196 /// Converts SPIR-V struct with a regular (according to `VulkanLayoutUtils`)
197 /// offset to LLVM struct. Otherwise, the conversion is not supported.
199  const TypeConverter &converter) {
200  if (type != VulkanLayoutUtils::decorateType(type))
201  return nullptr;
202 
203  SmallVector<Type> elementsVector;
204  if (failed(converter.convertTypes(type.getElementTypes(), elementsVector)))
205  return nullptr;
206  return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
207  /*isPacked=*/false);
208 }
209 
210 /// Converts SPIR-V struct with no offset to packed LLVM struct.
212  const TypeConverter &converter) {
213  SmallVector<Type> elementsVector;
214  if (failed(converter.convertTypes(type.getElementTypes(), elementsVector)))
215  return nullptr;
216  return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
217  /*isPacked=*/true);
218 }
219 
220 /// Creates LLVM dialect constant with the given value.
222  unsigned value) {
223  return rewriter.create<LLVM::ConstantOp>(
224  loc, IntegerType::get(rewriter.getContext(), 32),
225  rewriter.getIntegerAttr(rewriter.getI32Type(), value));
226 }
227 
228 /// Utility for `spirv.Load` and `spirv.Store` conversion.
229 static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands,
230  ConversionPatternRewriter &rewriter,
231  const TypeConverter &typeConverter,
232  unsigned alignment, bool isVolatile,
233  bool isNonTemporal) {
234  if (auto loadOp = dyn_cast<spirv::LoadOp>(op)) {
235  auto dstType = typeConverter.convertType(loadOp.getType());
236  if (!dstType)
237  return rewriter.notifyMatchFailure(op, "type conversion failed");
238  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
239  loadOp, dstType, spirv::LoadOpAdaptor(operands).getPtr(), alignment,
240  isVolatile, isNonTemporal);
241  return success();
242  }
243  auto storeOp = cast<spirv::StoreOp>(op);
244  spirv::StoreOpAdaptor adaptor(operands);
245  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValue(),
246  adaptor.getPtr(), alignment,
247  isVolatile, isNonTemporal);
248  return success();
249 }
250 
251 //===----------------------------------------------------------------------===//
252 // Type conversion
253 //===----------------------------------------------------------------------===//
254 
255 /// Converts SPIR-V array type to LLVM array. Natural stride (according to
256 /// `VulkanLayoutUtils`) is also mapped to LLVM array. This has to be respected
257 /// when converting ops that manipulate array types.
258 static std::optional<Type> convertArrayType(spirv::ArrayType type,
259  TypeConverter &converter) {
260  unsigned stride = type.getArrayStride();
261  Type elementType = type.getElementType();
262  auto sizeInBytes = cast<spirv::SPIRVType>(elementType).getSizeInBytes();
263  if (stride != 0 && (!sizeInBytes || *sizeInBytes != stride))
264  return std::nullopt;
265 
266  auto llvmElementType = converter.convertType(elementType);
267  unsigned numElements = type.getNumElements();
268  return LLVM::LLVMArrayType::get(llvmElementType, numElements);
269 }
270 
271 /// Converts SPIR-V pointer type to LLVM pointer. Pointer's storage class is not
272 /// modelled at the moment.
274  const TypeConverter &converter,
275  spirv::ClientAPI clientAPI) {
276  unsigned addressSpace =
277  storageClassToAddressSpace(clientAPI, type.getStorageClass());
278  return LLVM::LLVMPointerType::get(type.getContext(), addressSpace);
279 }
280 
281 /// Converts SPIR-V runtime array to LLVM array. Since LLVM allows indexing over
282 /// the bounds, the runtime array is converted to a 0-sized LLVM array. There is
283 /// no modelling of array stride at the moment.
284 static std::optional<Type> convertRuntimeArrayType(spirv::RuntimeArrayType type,
285  TypeConverter &converter) {
286  if (type.getArrayStride() != 0)
287  return std::nullopt;
288  auto elementType = converter.convertType(type.getElementType());
289  return LLVM::LLVMArrayType::get(elementType, 0);
290 }
291 
292 /// Converts SPIR-V struct to LLVM struct. There is no support of structs with
293 /// member decorations. Also, only natural offset is supported.
295  const TypeConverter &converter) {
297  type.getMemberDecorations(memberDecorations);
298  if (!memberDecorations.empty())
299  return nullptr;
300  if (type.hasOffset())
301  return convertStructTypeWithOffset(type, converter);
302  return convertStructTypePacked(type, converter);
303 }
304 
305 //===----------------------------------------------------------------------===//
306 // Operation conversion
307 //===----------------------------------------------------------------------===//
308 
309 namespace {
310 
311 class AccessChainPattern : public SPIRVToLLVMConversion<spirv::AccessChainOp> {
312 public:
314 
315  LogicalResult
316  matchAndRewrite(spirv::AccessChainOp op, OpAdaptor adaptor,
317  ConversionPatternRewriter &rewriter) const override {
318  auto dstType =
319  getTypeConverter()->convertType(op.getComponentPtr().getType());
320  if (!dstType)
321  return rewriter.notifyMatchFailure(op, "type conversion failed");
322  // To use GEP we need to add a first 0 index to go through the pointer.
323  auto indices = llvm::to_vector<4>(adaptor.getIndices());
324  Type indexType = op.getIndices().front().getType();
325  auto llvmIndexType = getTypeConverter()->convertType(indexType);
326  if (!llvmIndexType)
327  return rewriter.notifyMatchFailure(op, "type conversion failed");
328  Value zero = rewriter.create<LLVM::ConstantOp>(
329  op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0));
330  indices.insert(indices.begin(), zero);
331 
332  auto elementType = getTypeConverter()->convertType(
333  cast<spirv::PointerType>(op.getBasePtr().getType()).getPointeeType());
334  if (!elementType)
335  return rewriter.notifyMatchFailure(op, "type conversion failed");
336  rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, dstType, elementType,
337  adaptor.getBasePtr(), indices);
338  return success();
339  }
340 };
341 
342 class AddressOfPattern : public SPIRVToLLVMConversion<spirv::AddressOfOp> {
343 public:
345 
346  LogicalResult
347  matchAndRewrite(spirv::AddressOfOp op, OpAdaptor adaptor,
348  ConversionPatternRewriter &rewriter) const override {
349  auto dstType = getTypeConverter()->convertType(op.getPointer().getType());
350  if (!dstType)
351  return rewriter.notifyMatchFailure(op, "type conversion failed");
352  rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, dstType,
353  op.getVariable());
354  return success();
355  }
356 };
357 
358 class BitFieldInsertPattern
359  : public SPIRVToLLVMConversion<spirv::BitFieldInsertOp> {
360 public:
362 
363  LogicalResult
364  matchAndRewrite(spirv::BitFieldInsertOp op, OpAdaptor adaptor,
365  ConversionPatternRewriter &rewriter) const override {
366  auto srcType = op.getType();
367  auto dstType = getTypeConverter()->convertType(srcType);
368  if (!dstType)
369  return rewriter.notifyMatchFailure(op, "type conversion failed");
370  Location loc = op.getLoc();
371 
372  // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
373  Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
374  *getTypeConverter(), rewriter);
375  Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
376  *getTypeConverter(), rewriter);
377 
378  // Create a mask with bits set outside [Offset, Offset + Count - 1].
379  Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
380  Value maskShiftedByCount =
381  rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count);
382  Value negated = rewriter.create<LLVM::XOrOp>(loc, dstType,
383  maskShiftedByCount, minusOne);
384  Value maskShiftedByCountAndOffset =
385  rewriter.create<LLVM::ShlOp>(loc, dstType, negated, offset);
386  Value mask = rewriter.create<LLVM::XOrOp>(
387  loc, dstType, maskShiftedByCountAndOffset, minusOne);
388 
389  // Extract unchanged bits from the `Base` that are outside of
390  // [Offset, Offset + Count - 1]. Then `or` with shifted `Insert`.
391  Value baseAndMask =
392  rewriter.create<LLVM::AndOp>(loc, dstType, op.getBase(), mask);
393  Value insertShiftedByOffset =
394  rewriter.create<LLVM::ShlOp>(loc, dstType, op.getInsert(), offset);
395  rewriter.replaceOpWithNewOp<LLVM::OrOp>(op, dstType, baseAndMask,
396  insertShiftedByOffset);
397  return success();
398  }
399 };
400 
401 /// Converts SPIR-V ConstantOp with scalar or vector type.
402 class ConstantScalarAndVectorPattern
403  : public SPIRVToLLVMConversion<spirv::ConstantOp> {
404 public:
406 
407  LogicalResult
408  matchAndRewrite(spirv::ConstantOp constOp, OpAdaptor adaptor,
409  ConversionPatternRewriter &rewriter) const override {
410  auto srcType = constOp.getType();
411  if (!isa<VectorType>(srcType) && !srcType.isIntOrFloat())
412  return failure();
413 
414  auto dstType = getTypeConverter()->convertType(srcType);
415  if (!dstType)
416  return rewriter.notifyMatchFailure(constOp, "type conversion failed");
417 
418  // SPIR-V constant can be a signed/unsigned integer, which has to be
419  // casted to signless integer when converting to LLVM dialect. Removing the
420  // sign bit may have unexpected behaviour. However, it is better to handle
421  // it case-by-case, given that the purpose of the conversion is not to
422  // cover all possible corner cases.
423  if (isSignedIntegerOrVector(srcType) ||
424  isUnsignedIntegerOrVector(srcType)) {
425  auto signlessType = rewriter.getIntegerType(getBitWidth(srcType));
426 
427  if (isa<VectorType>(srcType)) {
428  auto dstElementsAttr = cast<DenseIntElementsAttr>(constOp.getValue());
429  rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
430  constOp, dstType,
431  dstElementsAttr.mapValues(
432  signlessType, [&](const APInt &value) { return value; }));
433  return success();
434  }
435  auto srcAttr = cast<IntegerAttr>(constOp.getValue());
436  auto dstAttr = rewriter.getIntegerAttr(signlessType, srcAttr.getValue());
437  rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, dstAttr);
438  return success();
439  }
440  rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
441  constOp, dstType, adaptor.getOperands(), constOp->getAttrs());
442  return success();
443  }
444 };
445 
446 class BitFieldSExtractPattern
447  : public SPIRVToLLVMConversion<spirv::BitFieldSExtractOp> {
448 public:
450 
451  LogicalResult
452  matchAndRewrite(spirv::BitFieldSExtractOp op, OpAdaptor adaptor,
453  ConversionPatternRewriter &rewriter) const override {
454  auto srcType = op.getType();
455  auto dstType = getTypeConverter()->convertType(srcType);
456  if (!dstType)
457  return rewriter.notifyMatchFailure(op, "type conversion failed");
458  Location loc = op.getLoc();
459 
460  // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
461  Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
462  *getTypeConverter(), rewriter);
463  Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
464  *getTypeConverter(), rewriter);
465 
466  // Create a constant that holds the size of the `Base`.
467  IntegerType integerType;
468  if (auto vecType = dyn_cast<VectorType>(srcType))
469  integerType = cast<IntegerType>(vecType.getElementType());
470  else
471  integerType = cast<IntegerType>(srcType);
472 
473  auto baseSize = rewriter.getIntegerAttr(integerType, getBitWidth(srcType));
474  Value size =
475  isa<VectorType>(srcType)
476  ? rewriter.create<LLVM::ConstantOp>(
477  loc, dstType,
478  SplatElementsAttr::get(cast<ShapedType>(srcType), baseSize))
479  : rewriter.create<LLVM::ConstantOp>(loc, dstType, baseSize);
480 
481  // Shift `Base` left by [sizeof(Base) - (Count + Offset)], so that the bit
482  // at Offset + Count - 1 is the most significant bit now.
483  Value countPlusOffset =
484  rewriter.create<LLVM::AddOp>(loc, dstType, count, offset);
485  Value amountToShiftLeft =
486  rewriter.create<LLVM::SubOp>(loc, dstType, size, countPlusOffset);
487  Value baseShiftedLeft = rewriter.create<LLVM::ShlOp>(
488  loc, dstType, op.getBase(), amountToShiftLeft);
489 
490  // Shift the result right, filling the bits with the sign bit.
491  Value amountToShiftRight =
492  rewriter.create<LLVM::AddOp>(loc, dstType, offset, amountToShiftLeft);
493  rewriter.replaceOpWithNewOp<LLVM::AShrOp>(op, dstType, baseShiftedLeft,
494  amountToShiftRight);
495  return success();
496  }
497 };
498 
499 class BitFieldUExtractPattern
500  : public SPIRVToLLVMConversion<spirv::BitFieldUExtractOp> {
501 public:
503 
504  LogicalResult
505  matchAndRewrite(spirv::BitFieldUExtractOp op, OpAdaptor adaptor,
506  ConversionPatternRewriter &rewriter) const override {
507  auto srcType = op.getType();
508  auto dstType = getTypeConverter()->convertType(srcType);
509  if (!dstType)
510  return rewriter.notifyMatchFailure(op, "type conversion failed");
511  Location loc = op.getLoc();
512 
513  // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
514  Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
515  *getTypeConverter(), rewriter);
516  Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
517  *getTypeConverter(), rewriter);
518 
519  // Create a mask with bits set at [0, Count - 1].
520  Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
521  Value maskShiftedByCount =
522  rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count);
523  Value mask = rewriter.create<LLVM::XOrOp>(loc, dstType, maskShiftedByCount,
524  minusOne);
525 
526  // Shift `Base` by `Offset` and apply the mask on it.
527  Value shiftedBase =
528  rewriter.create<LLVM::LShrOp>(loc, dstType, op.getBase(), offset);
529  rewriter.replaceOpWithNewOp<LLVM::AndOp>(op, dstType, shiftedBase, mask);
530  return success();
531  }
532 };
533 
534 class BranchConversionPattern : public SPIRVToLLVMConversion<spirv::BranchOp> {
535 public:
537 
538  LogicalResult
539  matchAndRewrite(spirv::BranchOp branchOp, OpAdaptor adaptor,
540  ConversionPatternRewriter &rewriter) const override {
541  rewriter.replaceOpWithNewOp<LLVM::BrOp>(branchOp, adaptor.getOperands(),
542  branchOp.getTarget());
543  return success();
544  }
545 };
546 
547 class BranchConditionalConversionPattern
548  : public SPIRVToLLVMConversion<spirv::BranchConditionalOp> {
549 public:
550  using SPIRVToLLVMConversion<
551  spirv::BranchConditionalOp>::SPIRVToLLVMConversion;
552 
553  LogicalResult
554  matchAndRewrite(spirv::BranchConditionalOp op, OpAdaptor adaptor,
555  ConversionPatternRewriter &rewriter) const override {
556  // If branch weights exist, map them to 32-bit integer vector.
557  DenseI32ArrayAttr branchWeights = nullptr;
558  if (auto weights = op.getBranchWeights()) {
559  SmallVector<int32_t> weightValues;
560  for (auto weight : weights->getAsRange<IntegerAttr>())
561  weightValues.push_back(weight.getInt());
562  branchWeights = DenseI32ArrayAttr::get(getContext(), weightValues);
563  }
564 
565  rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
566  op, op.getCondition(), op.getTrueBlockArguments(),
567  op.getFalseBlockArguments(), branchWeights, op.getTrueBlock(),
568  op.getFalseBlock());
569  return success();
570  }
571 };
572 
573 /// Converts `spirv.getCompositeExtract` to `llvm.extractvalue` if the container
574 /// type is an aggregate type (struct or array). Otherwise, converts to
575 /// `llvm.extractelement` that operates on vectors.
576 class CompositeExtractPattern
577  : public SPIRVToLLVMConversion<spirv::CompositeExtractOp> {
578 public:
580 
581  LogicalResult
582  matchAndRewrite(spirv::CompositeExtractOp op, OpAdaptor adaptor,
583  ConversionPatternRewriter &rewriter) const override {
584  auto dstType = this->getTypeConverter()->convertType(op.getType());
585  if (!dstType)
586  return rewriter.notifyMatchFailure(op, "type conversion failed");
587 
588  Type containerType = op.getComposite().getType();
589  if (isa<VectorType>(containerType)) {
590  Location loc = op.getLoc();
591  IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]);
592  Value index = createI32ConstantOf(loc, rewriter, value.getInt());
593  rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
594  op, dstType, adaptor.getComposite(), index);
595  return success();
596  }
597 
598  rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(
599  op, adaptor.getComposite(),
600  LLVM::convertArrayToIndices(op.getIndices()));
601  return success();
602  }
603 };
604 
605 /// Converts `spirv.getCompositeInsert` to `llvm.insertvalue` if the container
606 /// type is an aggregate type (struct or array). Otherwise, converts to
607 /// `llvm.insertelement` that operates on vectors.
608 class CompositeInsertPattern
609  : public SPIRVToLLVMConversion<spirv::CompositeInsertOp> {
610 public:
612 
613  LogicalResult
614  matchAndRewrite(spirv::CompositeInsertOp op, OpAdaptor adaptor,
615  ConversionPatternRewriter &rewriter) const override {
616  auto dstType = this->getTypeConverter()->convertType(op.getType());
617  if (!dstType)
618  return rewriter.notifyMatchFailure(op, "type conversion failed");
619 
620  Type containerType = op.getComposite().getType();
621  if (isa<VectorType>(containerType)) {
622  Location loc = op.getLoc();
623  IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]);
624  Value index = createI32ConstantOf(loc, rewriter, value.getInt());
625  rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
626  op, dstType, adaptor.getComposite(), adaptor.getObject(), index);
627  return success();
628  }
629 
630  rewriter.replaceOpWithNewOp<LLVM::InsertValueOp>(
631  op, adaptor.getComposite(), adaptor.getObject(),
632  LLVM::convertArrayToIndices(op.getIndices()));
633  return success();
634  }
635 };
636 
637 /// Converts SPIR-V operations that have straightforward LLVM equivalent
638 /// into LLVM dialect operations.
639 template <typename SPIRVOp, typename LLVMOp>
640 class DirectConversionPattern : public SPIRVToLLVMConversion<SPIRVOp> {
641 public:
643 
644  LogicalResult
645  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
646  ConversionPatternRewriter &rewriter) const override {
647  auto dstType = this->getTypeConverter()->convertType(op.getType());
648  if (!dstType)
649  return rewriter.notifyMatchFailure(op, "type conversion failed");
650  rewriter.template replaceOpWithNewOp<LLVMOp>(
651  op, dstType, adaptor.getOperands(), op->getAttrs());
652  return success();
653  }
654 };
655 
656 /// Converts `spirv.ExecutionMode` into a global struct constant that holds
657 /// execution mode information.
658 class ExecutionModePattern
659  : public SPIRVToLLVMConversion<spirv::ExecutionModeOp> {
660 public:
662 
663  LogicalResult
664  matchAndRewrite(spirv::ExecutionModeOp op, OpAdaptor adaptor,
665  ConversionPatternRewriter &rewriter) const override {
666  // First, create the global struct's name that would be associated with
667  // this entry point's execution mode. We set it to be:
668  // __spv__{SPIR-V module name}_{function name}_execution_mode_info_{mode}
669  ModuleOp module = op->getParentOfType<ModuleOp>();
670  spirv::ExecutionModeAttr executionModeAttr = op.getExecutionModeAttr();
671  std::string moduleName;
672  if (module.getName().has_value())
673  moduleName = "_" + module.getName()->str();
674  else
675  moduleName = "";
676  std::string executionModeInfoName = llvm::formatv(
677  "__spv_{0}_{1}_execution_mode_info_{2}", moduleName, op.getFn().str(),
678  static_cast<uint32_t>(executionModeAttr.getValue()));
679 
680  MLIRContext *context = rewriter.getContext();
681  OpBuilder::InsertionGuard guard(rewriter);
682  rewriter.setInsertionPointToStart(module.getBody());
683 
684  // Create a struct type, corresponding to the C struct below.
685  // struct {
686  // int32_t executionMode;
687  // int32_t values[]; // optional values
688  // };
689  auto llvmI32Type = IntegerType::get(context, 32);
690  SmallVector<Type, 2> fields;
691  fields.push_back(llvmI32Type);
692  ArrayAttr values = op.getValues();
693  if (!values.empty()) {
694  auto arrayType = LLVM::LLVMArrayType::get(llvmI32Type, values.size());
695  fields.push_back(arrayType);
696  }
697  auto structType = LLVM::LLVMStructType::getLiteral(context, fields);
698 
699  // Create `llvm.mlir.global` with initializer region containing one block.
700  auto global = rewriter.create<LLVM::GlobalOp>(
701  UnknownLoc::get(context), structType, /*isConstant=*/true,
702  LLVM::Linkage::External, executionModeInfoName, Attribute(),
703  /*alignment=*/0);
704  Location loc = global.getLoc();
705  Region &region = global.getInitializerRegion();
706  Block *block = rewriter.createBlock(&region);
707 
708  // Initialize the struct and set the execution mode value.
709  rewriter.setInsertionPoint(block, block->begin());
710  Value structValue = rewriter.create<LLVM::UndefOp>(loc, structType);
711  Value executionMode = rewriter.create<LLVM::ConstantOp>(
712  loc, llvmI32Type,
713  rewriter.getI32IntegerAttr(
714  static_cast<uint32_t>(executionModeAttr.getValue())));
715  structValue = rewriter.create<LLVM::InsertValueOp>(loc, structValue,
716  executionMode, 0);
717 
718  // Insert extra operands if they exist into execution mode info struct.
719  for (unsigned i = 0, e = values.size(); i < e; ++i) {
720  auto attr = values.getValue()[i];
721  Value entry = rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, attr);
722  structValue = rewriter.create<LLVM::InsertValueOp>(
723  loc, structValue, entry, ArrayRef<int64_t>({1, i}));
724  }
725  rewriter.create<LLVM::ReturnOp>(loc, ArrayRef<Value>({structValue}));
726  rewriter.eraseOp(op);
727  return success();
728  }
729 };
730 
731 /// Converts `spirv.GlobalVariable` to `llvm.mlir.global`. Note that SPIR-V
732 /// global returns a pointer, whereas in LLVM dialect the global holds an actual
733 /// value. This difference is handled by `spirv.mlir.addressof` and
734 /// `llvm.mlir.addressof`ops that both return a pointer.
735 class GlobalVariablePattern
736  : public SPIRVToLLVMConversion<spirv::GlobalVariableOp> {
737 public:
738  template <typename... Args>
739  GlobalVariablePattern(spirv::ClientAPI clientAPI, Args &&...args)
740  : SPIRVToLLVMConversion<spirv::GlobalVariableOp>(
741  std::forward<Args>(args)...),
742  clientAPI(clientAPI) {}
743 
744  LogicalResult
745  matchAndRewrite(spirv::GlobalVariableOp op, OpAdaptor adaptor,
746  ConversionPatternRewriter &rewriter) const override {
747  // Currently, there is no support of initialization with a constant value in
748  // SPIR-V dialect. Specialization constants are not considered as well.
749  if (op.getInitializer())
750  return failure();
751 
752  auto srcType = cast<spirv::PointerType>(op.getType());
753  auto dstType = getTypeConverter()->convertType(srcType.getPointeeType());
754  if (!dstType)
755  return rewriter.notifyMatchFailure(op, "type conversion failed");
756 
757  // Limit conversion to the current invocation only or `StorageBuffer`
758  // required by SPIR-V runner.
759  // This is okay because multiple invocations are not supported yet.
760  auto storageClass = srcType.getStorageClass();
761  switch (storageClass) {
762  case spirv::StorageClass::Input:
763  case spirv::StorageClass::Private:
764  case spirv::StorageClass::Output:
765  case spirv::StorageClass::StorageBuffer:
766  case spirv::StorageClass::UniformConstant:
767  break;
768  default:
769  return failure();
770  }
771 
772  // LLVM dialect spec: "If the global value is a constant, storing into it is
773  // not allowed.". This corresponds to SPIR-V 'Input' and 'UniformConstant'
774  // storage class that is read-only.
775  bool isConstant = (storageClass == spirv::StorageClass::Input) ||
776  (storageClass == spirv::StorageClass::UniformConstant);
777  // SPIR-V spec: "By default, functions and global variables are private to a
778  // module and cannot be accessed by other modules. However, a module may be
779  // written to export or import functions and global (module scope)
780  // variables.". Therefore, map 'Private' storage class to private linkage,
781  // 'Input' and 'Output' to external linkage.
782  auto linkage = storageClass == spirv::StorageClass::Private
783  ? LLVM::Linkage::Private
784  : LLVM::Linkage::External;
785  auto newGlobalOp = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
786  op, dstType, isConstant, linkage, op.getSymName(), Attribute(),
787  /*alignment=*/0, storageClassToAddressSpace(clientAPI, storageClass));
788 
789  // Attach location attribute if applicable
790  if (op.getLocationAttr())
791  newGlobalOp->setAttr(op.getLocationAttrName(), op.getLocationAttr());
792 
793  return success();
794  }
795 
796 private:
797  spirv::ClientAPI clientAPI;
798 };
799 
800 /// Converts SPIR-V cast ops that do not have straightforward LLVM
801 /// equivalent in LLVM dialect.
802 template <typename SPIRVOp, typename LLVMExtOp, typename LLVMTruncOp>
803 class IndirectCastPattern : public SPIRVToLLVMConversion<SPIRVOp> {
804 public:
806 
807  LogicalResult
808  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
809  ConversionPatternRewriter &rewriter) const override {
810 
811  Type fromType = op.getOperand().getType();
812  Type toType = op.getType();
813 
814  auto dstType = this->getTypeConverter()->convertType(toType);
815  if (!dstType)
816  return rewriter.notifyMatchFailure(op, "type conversion failed");
817 
818  if (getBitWidth(fromType) < getBitWidth(toType)) {
819  rewriter.template replaceOpWithNewOp<LLVMExtOp>(op, dstType,
820  adaptor.getOperands());
821  return success();
822  }
823  if (getBitWidth(fromType) > getBitWidth(toType)) {
824  rewriter.template replaceOpWithNewOp<LLVMTruncOp>(op, dstType,
825  adaptor.getOperands());
826  return success();
827  }
828  return failure();
829  }
830 };
831 
832 class FunctionCallPattern
833  : public SPIRVToLLVMConversion<spirv::FunctionCallOp> {
834 public:
836 
837  LogicalResult
838  matchAndRewrite(spirv::FunctionCallOp callOp, OpAdaptor adaptor,
839  ConversionPatternRewriter &rewriter) const override {
840  if (callOp.getNumResults() == 0) {
841  auto newOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
842  callOp, std::nullopt, adaptor.getOperands(), callOp->getAttrs());
843  newOp.getProperties().operandSegmentSizes = {
844  static_cast<int32_t>(adaptor.getOperands().size()), 0};
845  newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({});
846  return success();
847  }
848 
849  // Function returns a single result.
850  auto dstType = getTypeConverter()->convertType(callOp.getType(0));
851  if (!dstType)
852  return rewriter.notifyMatchFailure(callOp, "type conversion failed");
853  auto newOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
854  callOp, dstType, adaptor.getOperands(), callOp->getAttrs());
855  newOp.getProperties().operandSegmentSizes = {
856  static_cast<int32_t>(adaptor.getOperands().size()), 0};
857  newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({});
858  return success();
859  }
860 };
861 
862 /// Converts SPIR-V floating-point comparisons to llvm.fcmp "predicate"
863 template <typename SPIRVOp, LLVM::FCmpPredicate predicate>
864 class FComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
865 public:
867 
868  LogicalResult
869  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
870  ConversionPatternRewriter &rewriter) const override {
871 
872  auto dstType = this->getTypeConverter()->convertType(op.getType());
873  if (!dstType)
874  return rewriter.notifyMatchFailure(op, "type conversion failed");
875 
876  rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>(
877  op, dstType, predicate, op.getOperand1(), op.getOperand2());
878  return success();
879  }
880 };
881 
882 /// Converts SPIR-V integer comparisons to llvm.icmp "predicate"
883 template <typename SPIRVOp, LLVM::ICmpPredicate predicate>
884 class IComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
885 public:
887 
888  LogicalResult
889  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
890  ConversionPatternRewriter &rewriter) const override {
891 
892  auto dstType = this->getTypeConverter()->convertType(op.getType());
893  if (!dstType)
894  return rewriter.notifyMatchFailure(op, "type conversion failed");
895 
896  rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>(
897  op, dstType, predicate, op.getOperand1(), op.getOperand2());
898  return success();
899  }
900 };
901 
902 class InverseSqrtPattern
903  : public SPIRVToLLVMConversion<spirv::GLInverseSqrtOp> {
904 public:
906 
907  LogicalResult
908  matchAndRewrite(spirv::GLInverseSqrtOp op, OpAdaptor adaptor,
909  ConversionPatternRewriter &rewriter) const override {
910  auto srcType = op.getType();
911  auto dstType = getTypeConverter()->convertType(srcType);
912  if (!dstType)
913  return rewriter.notifyMatchFailure(op, "type conversion failed");
914 
915  Location loc = op.getLoc();
916  Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
917  Value sqrt = rewriter.create<LLVM::SqrtOp>(loc, dstType, op.getOperand());
918  rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, dstType, one, sqrt);
919  return success();
920  }
921 };
922 
923 /// Converts `spirv.Load` and `spirv.Store` to LLVM dialect.
924 template <typename SPIRVOp>
925 class LoadStorePattern : public SPIRVToLLVMConversion<SPIRVOp> {
926 public:
928 
929  LogicalResult
930  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
931  ConversionPatternRewriter &rewriter) const override {
932  if (!op.getMemoryAccess()) {
933  return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter,
934  *this->getTypeConverter(), /*alignment=*/0,
935  /*isVolatile=*/false,
936  /*isNonTemporal=*/false);
937  }
938  auto memoryAccess = *op.getMemoryAccess();
939  switch (memoryAccess) {
940  case spirv::MemoryAccess::Aligned:
942  case spirv::MemoryAccess::Nontemporal:
943  case spirv::MemoryAccess::Volatile: {
944  unsigned alignment =
945  memoryAccess == spirv::MemoryAccess::Aligned ? *op.getAlignment() : 0;
946  bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal;
947  bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile;
948  return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter,
949  *this->getTypeConverter(), alignment,
950  isVolatile, isNonTemporal);
951  }
952  default:
953  // There is no support of other memory access attributes.
954  return failure();
955  }
956  }
957 };
958 
959 /// Converts `spirv.Not` and `spirv.LogicalNot` into LLVM dialect.
960 template <typename SPIRVOp>
961 class NotPattern : public SPIRVToLLVMConversion<SPIRVOp> {
962 public:
964 
965  LogicalResult
966  matchAndRewrite(SPIRVOp notOp, typename SPIRVOp::Adaptor adaptor,
967  ConversionPatternRewriter &rewriter) const override {
968  auto srcType = notOp.getType();
969  auto dstType = this->getTypeConverter()->convertType(srcType);
970  if (!dstType)
971  return rewriter.notifyMatchFailure(notOp, "type conversion failed");
972 
973  Location loc = notOp.getLoc();
974  IntegerAttr minusOne = minusOneIntegerAttribute(srcType, rewriter);
975  auto mask =
976  isa<VectorType>(srcType)
977  ? rewriter.create<LLVM::ConstantOp>(
978  loc, dstType,
979  SplatElementsAttr::get(cast<VectorType>(srcType), minusOne))
980  : rewriter.create<LLVM::ConstantOp>(loc, dstType, minusOne);
981  rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(notOp, dstType,
982  notOp.getOperand(), mask);
983  return success();
984  }
985 };
986 
987 /// A template pattern that erases the given `SPIRVOp`.
988 template <typename SPIRVOp>
989 class ErasePattern : public SPIRVToLLVMConversion<SPIRVOp> {
990 public:
992 
993  LogicalResult
994  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
995  ConversionPatternRewriter &rewriter) const override {
996  rewriter.eraseOp(op);
997  return success();
998  }
999 };
1000 
1001 class ReturnPattern : public SPIRVToLLVMConversion<spirv::ReturnOp> {
1002 public:
1004 
1005  LogicalResult
1006  matchAndRewrite(spirv::ReturnOp returnOp, OpAdaptor adaptor,
1007  ConversionPatternRewriter &rewriter) const override {
1008  rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp, ArrayRef<Type>(),
1009  ArrayRef<Value>());
1010  return success();
1011  }
1012 };
1013 
1014 class ReturnValuePattern : public SPIRVToLLVMConversion<spirv::ReturnValueOp> {
1015 public:
1017 
1018  LogicalResult
1019  matchAndRewrite(spirv::ReturnValueOp returnValueOp, OpAdaptor adaptor,
1020  ConversionPatternRewriter &rewriter) const override {
1021  rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnValueOp, ArrayRef<Type>(),
1022  adaptor.getOperands());
1023  return success();
1024  }
1025 };
1026 
1027 static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable,
1028  StringRef name,
1029  ArrayRef<Type> paramTypes,
1030  Type resultType) {
1031  auto func = dyn_cast_or_null<LLVM::LLVMFuncOp>(
1032  SymbolTable::lookupSymbolIn(symbolTable, name));
1033  if (func)
1034  return func;
1035 
1036  OpBuilder b(symbolTable->getRegion(0));
1037  func = b.create<LLVM::LLVMFuncOp>(
1038  symbolTable->getLoc(), name,
1039  LLVM::LLVMFunctionType::get(resultType, paramTypes));
1040  func.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
1041  func.setConvergent(true);
1042  func.setNoUnwind(true);
1043  func.setWillReturn(true);
1044  return func;
1045 }
1046 
1047 static LLVM::CallOp createSPIRVBuiltinCall(Location loc, OpBuilder &builder,
1048  LLVM::LLVMFuncOp func,
1049  ValueRange args) {
1050  auto call = builder.create<LLVM::CallOp>(loc, func, args);
1051  call.setCConv(func.getCConv());
1052  call.setConvergentAttr(func.getConvergentAttr());
1053  call.setNoUnwindAttr(func.getNoUnwindAttr());
1054  call.setWillReturnAttr(func.getWillReturnAttr());
1055  return call;
1056 }
1057 
1058 class ControlBarrierPattern
1059  : public SPIRVToLLVMConversion<spirv::ControlBarrierOp> {
1060 public:
1062 
1063  LogicalResult
1064  matchAndRewrite(spirv::ControlBarrierOp controlBarrierOp, OpAdaptor adaptor,
1065  ConversionPatternRewriter &rewriter) const override {
1066  constexpr StringLiteral funcName = "_Z22__spirv_ControlBarrieriii";
1067  Operation *symbolTable =
1068  controlBarrierOp->getParentWithTrait<OpTrait::SymbolTable>();
1069 
1070  Type i32 = rewriter.getI32Type();
1071 
1072  Type voidTy = rewriter.getType<LLVM::LLVMVoidType>();
1073  LLVM::LLVMFuncOp func =
1074  lookupOrCreateSPIRVFn(symbolTable, funcName, {i32, i32, i32}, voidTy);
1075 
1076  Location loc = controlBarrierOp->getLoc();
1077  Value execution = rewriter.create<LLVM::ConstantOp>(
1078  loc, i32, static_cast<int32_t>(adaptor.getExecutionScope()));
1079  Value memory = rewriter.create<LLVM::ConstantOp>(
1080  loc, i32, static_cast<int32_t>(adaptor.getMemoryScope()));
1081  Value semantics = rewriter.create<LLVM::ConstantOp>(
1082  loc, i32, static_cast<int32_t>(adaptor.getMemorySemantics()));
1083 
1084  auto call = createSPIRVBuiltinCall(loc, rewriter, func,
1085  {execution, memory, semantics});
1086 
1087  rewriter.replaceOp(controlBarrierOp, call);
1088  return success();
1089  }
1090 };
1091 
1092 /// Converts `spirv.mlir.loop` to LLVM dialect. All blocks within selection
1093 /// should be reachable for conversion to succeed. The structure of the loop in
1094 /// LLVM dialect will be the following:
1095 ///
1096 /// +------------------------------------+
1097 /// | <code before spirv.mlir.loop> |
1098 /// | llvm.br ^header |
1099 /// +------------------------------------+
1100 /// |
1101 /// +----------------+ |
1102 /// | | |
1103 /// | V V
1104 /// | +------------------------------------+
1105 /// | | ^header: |
1106 /// | | <header code> |
1107 /// | | llvm.cond_br %cond, ^body, ^exit |
1108 /// | +------------------------------------+
1109 /// | |
1110 /// | |----------------------+
1111 /// | | |
1112 /// | V |
1113 /// | +------------------------------------+ |
1114 /// | | ^body: | |
1115 /// | | <body code> | |
1116 /// | | llvm.br ^continue | |
1117 /// | +------------------------------------+ |
1118 /// | | |
1119 /// | V |
1120 /// | +------------------------------------+ |
1121 /// | | ^continue: | |
1122 /// | | <continue code> | |
1123 /// | | llvm.br ^header | |
1124 /// | +------------------------------------+ |
1125 /// | | |
1126 /// +---------------+ +----------------------+
1127 /// |
1128 /// V
1129 /// +------------------------------------+
1130 /// | ^exit: |
1131 /// | llvm.br ^remaining |
1132 /// +------------------------------------+
1133 /// |
1134 /// V
1135 /// +------------------------------------+
1136 /// | ^remaining: |
1137 /// | <code after spirv.mlir.loop> |
1138 /// +------------------------------------+
1139 ///
1140 class LoopPattern : public SPIRVToLLVMConversion<spirv::LoopOp> {
1141 public:
1143 
1144  LogicalResult
1145  matchAndRewrite(spirv::LoopOp loopOp, OpAdaptor adaptor,
1146  ConversionPatternRewriter &rewriter) const override {
1147  // There is no support of loop control at the moment.
1148  if (loopOp.getLoopControl() != spirv::LoopControl::None)
1149  return failure();
1150 
1151  Location loc = loopOp.getLoc();
1152 
1153  // Split the current block after `spirv.mlir.loop`. The remaining ops will
1154  // be used in `endBlock`.
1155  Block *currentBlock = rewriter.getBlock();
1156  auto position = Block::iterator(loopOp);
1157  Block *endBlock = rewriter.splitBlock(currentBlock, position);
1158 
1159  // Remove entry block and create a branch in the current block going to the
1160  // header block.
1161  Block *entryBlock = loopOp.getEntryBlock();
1162  assert(entryBlock->getOperations().size() == 1);
1163  auto brOp = dyn_cast<spirv::BranchOp>(entryBlock->getOperations().front());
1164  if (!brOp)
1165  return failure();
1166  Block *headerBlock = loopOp.getHeaderBlock();
1167  rewriter.setInsertionPointToEnd(currentBlock);
1168  rewriter.create<LLVM::BrOp>(loc, brOp.getBlockArguments(), headerBlock);
1169  rewriter.eraseBlock(entryBlock);
1170 
1171  // Branch from merge block to end block.
1172  Block *mergeBlock = loopOp.getMergeBlock();
1173  Operation *terminator = mergeBlock->getTerminator();
1174  ValueRange terminatorOperands = terminator->getOperands();
1175  rewriter.setInsertionPointToEnd(mergeBlock);
1176  rewriter.create<LLVM::BrOp>(loc, terminatorOperands, endBlock);
1177 
1178  rewriter.inlineRegionBefore(loopOp.getBody(), endBlock);
1179  rewriter.replaceOp(loopOp, endBlock->getArguments());
1180  return success();
1181  }
1182 };
1183 
1184 /// Converts `spirv.mlir.selection` with `spirv.BranchConditional` in its header
1185 /// block. All blocks within selection should be reachable for conversion to
1186 /// succeed.
1187 class SelectionPattern : public SPIRVToLLVMConversion<spirv::SelectionOp> {
1188 public:
1190 
1191  LogicalResult
1192  matchAndRewrite(spirv::SelectionOp op, OpAdaptor adaptor,
1193  ConversionPatternRewriter &rewriter) const override {
1194  // There is no support for `Flatten` or `DontFlatten` selection control at
1195  // the moment. This are just compiler hints and can be performed during the
1196  // optimization passes.
1197  if (op.getSelectionControl() != spirv::SelectionControl::None)
1198  return failure();
1199 
1200  // `spirv.mlir.selection` should have at least two blocks: one selection
1201  // header block and one merge block. If no blocks are present, or control
1202  // flow branches straight to merge block (two blocks are present), the op is
1203  // redundant and it is erased.
1204  if (op.getBody().getBlocks().size() <= 2) {
1205  rewriter.eraseOp(op);
1206  return success();
1207  }
1208 
1209  Location loc = op.getLoc();
1210 
1211  // Split the current block after `spirv.mlir.selection`. The remaining ops
1212  // will be used in `continueBlock`.
1213  auto *currentBlock = rewriter.getInsertionBlock();
1214  rewriter.setInsertionPointAfter(op);
1215  auto position = rewriter.getInsertionPoint();
1216  auto *continueBlock = rewriter.splitBlock(currentBlock, position);
1217 
1218  // Extract conditional branch information from the header block. By SPIR-V
1219  // dialect spec, it should contain `spirv.BranchConditional` or
1220  // `spirv.Switch` op. Note that `spirv.Switch op` is not supported at the
1221  // moment in the SPIR-V dialect. Remove this block when finished.
1222  auto *headerBlock = op.getHeaderBlock();
1223  assert(headerBlock->getOperations().size() == 1);
1224  auto condBrOp = dyn_cast<spirv::BranchConditionalOp>(
1225  headerBlock->getOperations().front());
1226  if (!condBrOp)
1227  return failure();
1228  rewriter.eraseBlock(headerBlock);
1229 
1230  // Branch from merge block to continue block.
1231  auto *mergeBlock = op.getMergeBlock();
1232  Operation *terminator = mergeBlock->getTerminator();
1233  ValueRange terminatorOperands = terminator->getOperands();
1234  rewriter.setInsertionPointToEnd(mergeBlock);
1235  rewriter.create<LLVM::BrOp>(loc, terminatorOperands, continueBlock);
1236 
1237  // Link current block to `true` and `false` blocks within the selection.
1238  Block *trueBlock = condBrOp.getTrueBlock();
1239  Block *falseBlock = condBrOp.getFalseBlock();
1240  rewriter.setInsertionPointToEnd(currentBlock);
1241  rewriter.create<LLVM::CondBrOp>(loc, condBrOp.getCondition(), trueBlock,
1242  condBrOp.getTrueTargetOperands(),
1243  falseBlock,
1244  condBrOp.getFalseTargetOperands());
1245 
1246  rewriter.inlineRegionBefore(op.getBody(), continueBlock);
1247  rewriter.replaceOp(op, continueBlock->getArguments());
1248  return success();
1249  }
1250 };
1251 
1252 /// Converts SPIR-V shift ops to LLVM shift ops. Since LLVM dialect
1253 /// puts a restriction on `Shift` and `Base` to have the same bit width,
1254 /// `Shift` is zero or sign extended to match this specification. Cases when
1255 /// `Shift` bit width > `Base` bit width are considered to be illegal.
1256 template <typename SPIRVOp, typename LLVMOp>
1257 class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> {
1258 public:
1260 
1261  LogicalResult
1262  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
1263  ConversionPatternRewriter &rewriter) const override {
1264 
1265  auto dstType = this->getTypeConverter()->convertType(op.getType());
1266  if (!dstType)
1267  return rewriter.notifyMatchFailure(op, "type conversion failed");
1268 
1269  Type op1Type = op.getOperand1().getType();
1270  Type op2Type = op.getOperand2().getType();
1271 
1272  if (op1Type == op2Type) {
1273  rewriter.template replaceOpWithNewOp<LLVMOp>(op, dstType,
1274  adaptor.getOperands());
1275  return success();
1276  }
1277 
1278  std::optional<uint64_t> dstTypeWidth =
1280  std::optional<uint64_t> op2TypeWidth =
1282 
1283  if (!dstTypeWidth || !op2TypeWidth)
1284  return failure();
1285 
1286  Location loc = op.getLoc();
1287  Value extended;
1288  if (op2TypeWidth < dstTypeWidth) {
1289  if (isUnsignedIntegerOrVector(op2Type)) {
1290  extended = rewriter.template create<LLVM::ZExtOp>(
1291  loc, dstType, adaptor.getOperand2());
1292  } else {
1293  extended = rewriter.template create<LLVM::SExtOp>(
1294  loc, dstType, adaptor.getOperand2());
1295  }
1296  } else if (op2TypeWidth == dstTypeWidth) {
1297  extended = adaptor.getOperand2();
1298  } else {
1299  return failure();
1300  }
1301 
1302  Value result = rewriter.template create<LLVMOp>(
1303  loc, dstType, adaptor.getOperand1(), extended);
1304  rewriter.replaceOp(op, result);
1305  return success();
1306  }
1307 };
1308 
1309 class TanPattern : public SPIRVToLLVMConversion<spirv::GLTanOp> {
1310 public:
1312 
1313  LogicalResult
1314  matchAndRewrite(spirv::GLTanOp tanOp, OpAdaptor adaptor,
1315  ConversionPatternRewriter &rewriter) const override {
1316  auto dstType = getTypeConverter()->convertType(tanOp.getType());
1317  if (!dstType)
1318  return rewriter.notifyMatchFailure(tanOp, "type conversion failed");
1319 
1320  Location loc = tanOp.getLoc();
1321  Value sin = rewriter.create<LLVM::SinOp>(loc, dstType, tanOp.getOperand());
1322  Value cos = rewriter.create<LLVM::CosOp>(loc, dstType, tanOp.getOperand());
1323  rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanOp, dstType, sin, cos);
1324  return success();
1325  }
1326 };
1327 
1328 /// Convert `spirv.Tanh` to
1329 ///
1330 /// exp(2x) - 1
1331 /// -----------
1332 /// exp(2x) + 1
1333 ///
1334 class TanhPattern : public SPIRVToLLVMConversion<spirv::GLTanhOp> {
1335 public:
1337 
1338  LogicalResult
1339  matchAndRewrite(spirv::GLTanhOp tanhOp, OpAdaptor adaptor,
1340  ConversionPatternRewriter &rewriter) const override {
1341  auto srcType = tanhOp.getType();
1342  auto dstType = getTypeConverter()->convertType(srcType);
1343  if (!dstType)
1344  return rewriter.notifyMatchFailure(tanhOp, "type conversion failed");
1345 
1346  Location loc = tanhOp.getLoc();
1347  Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0);
1348  Value multiplied =
1349  rewriter.create<LLVM::FMulOp>(loc, dstType, two, tanhOp.getOperand());
1350  Value exponential = rewriter.create<LLVM::ExpOp>(loc, dstType, multiplied);
1351  Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
1352  Value numerator =
1353  rewriter.create<LLVM::FSubOp>(loc, dstType, exponential, one);
1354  Value denominator =
1355  rewriter.create<LLVM::FAddOp>(loc, dstType, exponential, one);
1356  rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanhOp, dstType, numerator,
1357  denominator);
1358  return success();
1359  }
1360 };
1361 
1362 class VariablePattern : public SPIRVToLLVMConversion<spirv::VariableOp> {
1363 public:
1365 
1366  LogicalResult
1367  matchAndRewrite(spirv::VariableOp varOp, OpAdaptor adaptor,
1368  ConversionPatternRewriter &rewriter) const override {
1369  auto srcType = varOp.getType();
1370  // Initialization is supported for scalars and vectors only.
1371  auto pointerTo = cast<spirv::PointerType>(srcType).getPointeeType();
1372  auto init = varOp.getInitializer();
1373  if (init && !pointerTo.isIntOrFloat() && !isa<VectorType>(pointerTo))
1374  return failure();
1375 
1376  auto dstType = getTypeConverter()->convertType(srcType);
1377  if (!dstType)
1378  return rewriter.notifyMatchFailure(varOp, "type conversion failed");
1379 
1380  Location loc = varOp.getLoc();
1381  Value size = createI32ConstantOf(loc, rewriter, 1);
1382  if (!init) {
1383  auto elementType = getTypeConverter()->convertType(pointerTo);
1384  if (!elementType)
1385  return rewriter.notifyMatchFailure(varOp, "type conversion failed");
1386  rewriter.replaceOpWithNewOp<LLVM::AllocaOp>(varOp, dstType, elementType,
1387  size);
1388  return success();
1389  }
1390  auto elementType = getTypeConverter()->convertType(pointerTo);
1391  if (!elementType)
1392  return rewriter.notifyMatchFailure(varOp, "type conversion failed");
1393  Value allocated =
1394  rewriter.create<LLVM::AllocaOp>(loc, dstType, elementType, size);
1395  rewriter.create<LLVM::StoreOp>(loc, adaptor.getInitializer(), allocated);
1396  rewriter.replaceOp(varOp, allocated);
1397  return success();
1398  }
1399 };
1400 
1401 //===----------------------------------------------------------------------===//
1402 // BitcastOp conversion
1403 //===----------------------------------------------------------------------===//
1404 
1405 class BitcastConversionPattern
1406  : public SPIRVToLLVMConversion<spirv::BitcastOp> {
1407 public:
1409 
1410  LogicalResult
1411  matchAndRewrite(spirv::BitcastOp bitcastOp, OpAdaptor adaptor,
1412  ConversionPatternRewriter &rewriter) const override {
1413  auto dstType = getTypeConverter()->convertType(bitcastOp.getType());
1414  if (!dstType)
1415  return rewriter.notifyMatchFailure(bitcastOp, "type conversion failed");
1416 
1417  // LLVM's opaque pointers do not require bitcasts.
1418  if (isa<LLVM::LLVMPointerType>(dstType)) {
1419  rewriter.replaceOp(bitcastOp, adaptor.getOperand());
1420  return success();
1421  }
1422 
1423  rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
1424  bitcastOp, dstType, adaptor.getOperands(), bitcastOp->getAttrs());
1425  return success();
1426  }
1427 };
1428 
1429 //===----------------------------------------------------------------------===//
1430 // FuncOp conversion
1431 //===----------------------------------------------------------------------===//
1432 
1433 class FuncConversionPattern : public SPIRVToLLVMConversion<spirv::FuncOp> {
1434 public:
1436 
1437  LogicalResult
1438  matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor,
1439  ConversionPatternRewriter &rewriter) const override {
1440 
1441  // Convert function signature. At the moment LLVMType converter is enough
1442  // for currently supported types.
1443  auto funcType = funcOp.getFunctionType();
1444  TypeConverter::SignatureConversion signatureConverter(
1445  funcType.getNumInputs());
1446  auto llvmType = static_cast<const LLVMTypeConverter *>(getTypeConverter())
1447  ->convertFunctionSignature(
1448  funcType, /*isVariadic=*/false,
1449  /*useBarePtrCallConv=*/false, signatureConverter);
1450  if (!llvmType)
1451  return failure();
1452 
1453  // Create a new `LLVMFuncOp`
1454  Location loc = funcOp.getLoc();
1455  StringRef name = funcOp.getName();
1456  auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(loc, name, llvmType);
1457 
1458  // Convert SPIR-V Function Control to equivalent LLVM function attribute
1459  MLIRContext *context = funcOp.getContext();
1460  switch (funcOp.getFunctionControl()) {
1461  case spirv::FunctionControl::Inline:
1462  newFuncOp.setAlwaysInline(true);
1463  break;
1464  case spirv::FunctionControl::DontInline:
1465  newFuncOp.setNoInline(true);
1466  break;
1467 
1468 #define DISPATCH(functionControl, llvmAttr) \
1469  case functionControl: \
1470  newFuncOp->setAttr("passthrough", ArrayAttr::get(context, {llvmAttr})); \
1471  break;
1472 
1473  DISPATCH(spirv::FunctionControl::Pure,
1474  StringAttr::get(context, "readonly"));
1475  DISPATCH(spirv::FunctionControl::Const,
1476  StringAttr::get(context, "readnone"));
1477 
1478 #undef DISPATCH
1479 
1480  // Default: if `spirv::FunctionControl::None`, then no attributes are
1481  // needed.
1482  default:
1483  break;
1484  }
1485 
1486  rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
1487  newFuncOp.end());
1488  if (failed(rewriter.convertRegionTypes(
1489  &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter))) {
1490  return failure();
1491  }
1492  rewriter.eraseOp(funcOp);
1493  return success();
1494  }
1495 };
1496 
1497 //===----------------------------------------------------------------------===//
1498 // ModuleOp conversion
1499 //===----------------------------------------------------------------------===//
1500 
1501 class ModuleConversionPattern : public SPIRVToLLVMConversion<spirv::ModuleOp> {
1502 public:
1504 
1505  LogicalResult
1506  matchAndRewrite(spirv::ModuleOp spvModuleOp, OpAdaptor adaptor,
1507  ConversionPatternRewriter &rewriter) const override {
1508 
1509  auto newModuleOp =
1510  rewriter.create<ModuleOp>(spvModuleOp.getLoc(), spvModuleOp.getName());
1511  rewriter.inlineRegionBefore(spvModuleOp.getRegion(), newModuleOp.getBody());
1512 
1513  // Remove the terminator block that was automatically added by builder
1514  rewriter.eraseBlock(&newModuleOp.getBodyRegion().back());
1515  rewriter.eraseOp(spvModuleOp);
1516  return success();
1517  }
1518 };
1519 
1520 //===----------------------------------------------------------------------===//
1521 // VectorShuffleOp conversion
1522 //===----------------------------------------------------------------------===//
1523 
1524 class VectorShufflePattern
1525  : public SPIRVToLLVMConversion<spirv::VectorShuffleOp> {
1526 public:
1528  LogicalResult
1529  matchAndRewrite(spirv::VectorShuffleOp op, OpAdaptor adaptor,
1530  ConversionPatternRewriter &rewriter) const override {
1531  Location loc = op.getLoc();
1532  auto components = adaptor.getComponents();
1533  auto vector1 = adaptor.getVector1();
1534  auto vector2 = adaptor.getVector2();
1535  int vector1Size = cast<VectorType>(vector1.getType()).getNumElements();
1536  int vector2Size = cast<VectorType>(vector2.getType()).getNumElements();
1537  if (vector1Size == vector2Size) {
1538  rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(
1539  op, vector1, vector2,
1540  LLVM::convertArrayToIndices<int32_t>(components));
1541  return success();
1542  }
1543 
1544  auto dstType = getTypeConverter()->convertType(op.getType());
1545  if (!dstType)
1546  return rewriter.notifyMatchFailure(op, "type conversion failed");
1547  auto scalarType = cast<VectorType>(dstType).getElementType();
1548  auto componentsArray = components.getValue();
1549  auto *context = rewriter.getContext();
1550  auto llvmI32Type = IntegerType::get(context, 32);
1551  Value targetOp = rewriter.create<LLVM::UndefOp>(loc, dstType);
1552  for (unsigned i = 0; i < componentsArray.size(); i++) {
1553  if (!isa<IntegerAttr>(componentsArray[i]))
1554  return op.emitError("unable to support non-constant component");
1555 
1556  int indexVal = cast<IntegerAttr>(componentsArray[i]).getInt();
1557  if (indexVal == -1)
1558  continue;
1559 
1560  int offsetVal = 0;
1561  Value baseVector = vector1;
1562  if (indexVal >= vector1Size) {
1563  offsetVal = vector1Size;
1564  baseVector = vector2;
1565  }
1566 
1567  Value dstIndex = rewriter.create<LLVM::ConstantOp>(
1568  loc, llvmI32Type, rewriter.getIntegerAttr(rewriter.getI32Type(), i));
1569  Value index = rewriter.create<LLVM::ConstantOp>(
1570  loc, llvmI32Type,
1571  rewriter.getIntegerAttr(rewriter.getI32Type(), indexVal - offsetVal));
1572 
1573  auto extractOp = rewriter.create<LLVM::ExtractElementOp>(
1574  loc, scalarType, baseVector, index);
1575  targetOp = rewriter.create<LLVM::InsertElementOp>(loc, dstType, targetOp,
1576  extractOp, dstIndex);
1577  }
1578  rewriter.replaceOp(op, targetOp);
1579  return success();
1580  }
1581 };
1582 } // namespace
1583 
1584 //===----------------------------------------------------------------------===//
1585 // Pattern population
1586 //===----------------------------------------------------------------------===//
1587 
1589  spirv::ClientAPI clientAPI) {
1590  typeConverter.addConversion([&](spirv::ArrayType type) {
1591  return convertArrayType(type, typeConverter);
1592  });
1593  typeConverter.addConversion([&, clientAPI](spirv::PointerType type) {
1594  return convertPointerType(type, typeConverter, clientAPI);
1595  });
1596  typeConverter.addConversion([&](spirv::RuntimeArrayType type) {
1597  return convertRuntimeArrayType(type, typeConverter);
1598  });
1599  typeConverter.addConversion([&](spirv::StructType type) {
1600  return convertStructType(type, typeConverter);
1601  });
1602 }
1603 
1605  const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
1606  spirv::ClientAPI clientAPI) {
1607  patterns.add<
1608  // Arithmetic ops
1609  DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>,
1610  DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>,
1611  DirectConversionPattern<spirv::ISubOp, LLVM::SubOp>,
1612  DirectConversionPattern<spirv::FAddOp, LLVM::FAddOp>,
1613  DirectConversionPattern<spirv::FDivOp, LLVM::FDivOp>,
1614  DirectConversionPattern<spirv::FMulOp, LLVM::FMulOp>,
1615  DirectConversionPattern<spirv::FNegateOp, LLVM::FNegOp>,
1616  DirectConversionPattern<spirv::FRemOp, LLVM::FRemOp>,
1617  DirectConversionPattern<spirv::FSubOp, LLVM::FSubOp>,
1618  DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>,
1619  DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>,
1620  DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>,
1621  DirectConversionPattern<spirv::UModOp, LLVM::URemOp>,
1622 
1623  // Bitwise ops
1624  BitFieldInsertPattern, BitFieldUExtractPattern, BitFieldSExtractPattern,
1625  DirectConversionPattern<spirv::BitCountOp, LLVM::CtPopOp>,
1626  DirectConversionPattern<spirv::BitReverseOp, LLVM::BitReverseOp>,
1627  DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>,
1628  DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>,
1629  DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>,
1630  NotPattern<spirv::NotOp>,
1631 
1632  // Cast ops
1633  BitcastConversionPattern,
1634  DirectConversionPattern<spirv::ConvertFToSOp, LLVM::FPToSIOp>,
1635  DirectConversionPattern<spirv::ConvertFToUOp, LLVM::FPToUIOp>,
1636  DirectConversionPattern<spirv::ConvertSToFOp, LLVM::SIToFPOp>,
1637  DirectConversionPattern<spirv::ConvertUToFOp, LLVM::UIToFPOp>,
1638  IndirectCastPattern<spirv::FConvertOp, LLVM::FPExtOp, LLVM::FPTruncOp>,
1639  IndirectCastPattern<spirv::SConvertOp, LLVM::SExtOp, LLVM::TruncOp>,
1640  IndirectCastPattern<spirv::UConvertOp, LLVM::ZExtOp, LLVM::TruncOp>,
1641 
1642  // Comparison ops
1643  IComparePattern<spirv::IEqualOp, LLVM::ICmpPredicate::eq>,
1644  IComparePattern<spirv::INotEqualOp, LLVM::ICmpPredicate::ne>,
1645  FComparePattern<spirv::FOrdEqualOp, LLVM::FCmpPredicate::oeq>,
1646  FComparePattern<spirv::FOrdGreaterThanOp, LLVM::FCmpPredicate::ogt>,
1647  FComparePattern<spirv::FOrdGreaterThanEqualOp, LLVM::FCmpPredicate::oge>,
1648  FComparePattern<spirv::FOrdLessThanEqualOp, LLVM::FCmpPredicate::ole>,
1649  FComparePattern<spirv::FOrdLessThanOp, LLVM::FCmpPredicate::olt>,
1650  FComparePattern<spirv::FOrdNotEqualOp, LLVM::FCmpPredicate::one>,
1651  FComparePattern<spirv::FUnordEqualOp, LLVM::FCmpPredicate::ueq>,
1652  FComparePattern<spirv::FUnordGreaterThanOp, LLVM::FCmpPredicate::ugt>,
1653  FComparePattern<spirv::FUnordGreaterThanEqualOp,
1654  LLVM::FCmpPredicate::uge>,
1655  FComparePattern<spirv::FUnordLessThanEqualOp, LLVM::FCmpPredicate::ule>,
1656  FComparePattern<spirv::FUnordLessThanOp, LLVM::FCmpPredicate::ult>,
1657  FComparePattern<spirv::FUnordNotEqualOp, LLVM::FCmpPredicate::une>,
1658  IComparePattern<spirv::SGreaterThanOp, LLVM::ICmpPredicate::sgt>,
1659  IComparePattern<spirv::SGreaterThanEqualOp, LLVM::ICmpPredicate::sge>,
1660  IComparePattern<spirv::SLessThanEqualOp, LLVM::ICmpPredicate::sle>,
1661  IComparePattern<spirv::SLessThanOp, LLVM::ICmpPredicate::slt>,
1662  IComparePattern<spirv::UGreaterThanOp, LLVM::ICmpPredicate::ugt>,
1663  IComparePattern<spirv::UGreaterThanEqualOp, LLVM::ICmpPredicate::uge>,
1664  IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>,
1665  IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>,
1666 
1667  // Constant op
1668  ConstantScalarAndVectorPattern,
1669 
1670  // Control Flow ops
1671  BranchConversionPattern, BranchConditionalConversionPattern,
1672  FunctionCallPattern, LoopPattern, SelectionPattern,
1673  ErasePattern<spirv::MergeOp>,
1674 
1675  // Entry points and execution mode are handled separately.
1676  ErasePattern<spirv::EntryPointOp>, ExecutionModePattern,
1677 
1678  // GLSL extended instruction set ops
1679  DirectConversionPattern<spirv::GLCeilOp, LLVM::FCeilOp>,
1680  DirectConversionPattern<spirv::GLCosOp, LLVM::CosOp>,
1681  DirectConversionPattern<spirv::GLExpOp, LLVM::ExpOp>,
1682  DirectConversionPattern<spirv::GLFAbsOp, LLVM::FAbsOp>,
1683  DirectConversionPattern<spirv::GLFloorOp, LLVM::FFloorOp>,
1684  DirectConversionPattern<spirv::GLFMaxOp, LLVM::MaxNumOp>,
1685  DirectConversionPattern<spirv::GLFMinOp, LLVM::MinNumOp>,
1686  DirectConversionPattern<spirv::GLLogOp, LLVM::LogOp>,
1687  DirectConversionPattern<spirv::GLSinOp, LLVM::SinOp>,
1688  DirectConversionPattern<spirv::GLSMaxOp, LLVM::SMaxOp>,
1689  DirectConversionPattern<spirv::GLSMinOp, LLVM::SMinOp>,
1690  DirectConversionPattern<spirv::GLSqrtOp, LLVM::SqrtOp>,
1691  InverseSqrtPattern, TanPattern, TanhPattern,
1692 
1693  // Logical ops
1694  DirectConversionPattern<spirv::LogicalAndOp, LLVM::AndOp>,
1695  DirectConversionPattern<spirv::LogicalOrOp, LLVM::OrOp>,
1696  IComparePattern<spirv::LogicalEqualOp, LLVM::ICmpPredicate::eq>,
1697  IComparePattern<spirv::LogicalNotEqualOp, LLVM::ICmpPredicate::ne>,
1698  NotPattern<spirv::LogicalNotOp>,
1699 
1700  // Memory ops
1701  AccessChainPattern, AddressOfPattern, LoadStorePattern<spirv::LoadOp>,
1702  LoadStorePattern<spirv::StoreOp>, VariablePattern,
1703 
1704  // Miscellaneous ops
1705  CompositeExtractPattern, CompositeInsertPattern,
1706  DirectConversionPattern<spirv::SelectOp, LLVM::SelectOp>,
1707  DirectConversionPattern<spirv::UndefOp, LLVM::UndefOp>,
1708  VectorShufflePattern,
1709 
1710  // Shift ops
1711  ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>,
1712  ShiftPattern<spirv::ShiftRightLogicalOp, LLVM::LShrOp>,
1713  ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>,
1714 
1715  // Return ops
1716  ReturnPattern, ReturnValuePattern,
1717 
1718  // Barrier ops
1719  ControlBarrierPattern>(patterns.getContext(), typeConverter);
1720 
1721  patterns.add<GlobalVariablePattern>(clientAPI, patterns.getContext(),
1722  typeConverter);
1723 }
1724 
1726  const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
1727  patterns.add<FuncConversionPattern>(patterns.getContext(), typeConverter);
1728 }
1729 
1731  const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
1732  patterns.add<ModuleConversionPattern>(patterns.getContext(), typeConverter);
1733 }
1734 
1735 //===----------------------------------------------------------------------===//
1736 // Pre-conversion hooks
1737 //===----------------------------------------------------------------------===//
1738 
1739 /// Hook for descriptor set and binding number encoding.
1740 static constexpr StringRef kBinding = "binding";
1741 static constexpr StringRef kDescriptorSet = "descriptor_set";
1742 void mlir::encodeBindAttribute(ModuleOp module) {
1743  auto spvModules = module.getOps<spirv::ModuleOp>();
1744  for (auto spvModule : spvModules) {
1745  spvModule.walk([&](spirv::GlobalVariableOp op) {
1746  IntegerAttr descriptorSet =
1747  op->getAttrOfType<IntegerAttr>(kDescriptorSet);
1748  IntegerAttr binding = op->getAttrOfType<IntegerAttr>(kBinding);
1749  // For every global variable in the module, get the ones with descriptor
1750  // set and binding numbers.
1751  if (descriptorSet && binding) {
1752  // Encode these numbers into the variable's symbolic name. If the
1753  // SPIR-V module has a name, add it at the beginning.
1754  auto moduleAndName =
1755  spvModule.getName().has_value()
1756  ? spvModule.getName()->str() + "_" + op.getSymName().str()
1757  : op.getSymName().str();
1758  std::string name =
1759  llvm::formatv("{0}_descriptor_set{1}_binding{2}", moduleAndName,
1760  std::to_string(descriptorSet.getInt()),
1761  std::to_string(binding.getInt()));
1762  auto nameAttr = StringAttr::get(op->getContext(), name);
1763 
1764  // Replace all symbol uses and set the new symbol name. Finally, remove
1765  // descriptor set and binding attributes.
1766  if (failed(SymbolTable::replaceAllSymbolUses(op, nameAttr, spvModule)))
1767  op.emitError("unable to replace all symbol uses for ") << name;
1768  SymbolTable::setSymbolName(op, nameAttr);
1770  op->removeAttr(kBinding);
1771  }
1772  });
1773  }
1774 }
static LLVM::CallOp createSPIRVBuiltinCall(Location loc, ConversionPatternRewriter &rewriter, LLVM::LLVMFuncOp func, ValueRange args)
static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable, StringRef name, ArrayRef< Type > paramTypes, Type resultType, bool isMemNone, bool isConvergent)
static MLIRContext * getContext(OpFoldResult val)
@ None
static Value optionallyTruncateOrExtend(Location loc, Value value, Type llvmType, PatternRewriter &rewriter)
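Utility function for bitfield ops (BitFieldInsert, BitFieldSExtract and BitFieldUExtract): optionally truncates or extends `value` to `llvmType`.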
static Value createFPConstant(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter, double value)
Creates llvm.mlir.constant with a floating-point scalar or vector value.
static constexpr StringRef kDescriptorSet
static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter, unsigned value)
Creates LLVM dialect constant with the given value.
static Type convertPointerType(spirv::PointerType type, const TypeConverter &converter, spirv::ClientAPI clientAPI)
Converts SPIR-V pointer type to LLVM pointer.
static Value processCountOrOffset(Location loc, Value value, Type srcType, Type dstType, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Utility function for bitfield ops: BitFieldInsert, BitFieldSExtract and BitFieldUExtract.
static unsigned getBitWidth(Type type)
Returns the bit width of integer, float or vector of float or integer values.
Definition: SPIRVToLLVM.cpp:66
static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter, const TypeConverter &typeConverter, unsigned alignment, bool isVolatile, bool isNonTemporal)
Utility for spirv.Load and spirv.Store conversion.
static Type convertStructTypePacked(spirv::StructType type, const TypeConverter &converter)
Converts SPIR-V struct with no offset to packed LLVM struct.
static std::optional< Type > convertRuntimeArrayType(spirv::RuntimeArrayType type, TypeConverter &converter)
Converts SPIR-V runtime array to LLVM array.
static bool isSignedIntegerOrVector(Type type)
Returns true if the given type is a signed integer or vector type.
Definition: SPIRVToLLVM.cpp:37
static bool isUnsignedIntegerOrVector(Type type)
Returns true if the given type is an unsigned integer or vector type.
Definition: SPIRVToLLVM.cpp:46
static std::optional< Type > convertArrayType(spirv::ArrayType type, TypeConverter &converter)
Converts SPIR-V array type to LLVM array.
static constexpr StringRef kBinding
Hook for descriptor set and binding number encoding.
static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder)
Creates IntegerAttribute with all bits set for given type.
Definition: SPIRVToLLVM.cpp:87
static Value optionallyBroadcast(Location loc, Value value, Type srcType, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value. If srcType is a scalar, the value remains unchanged.
static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter)
Creates llvm.mlir.constant with all bits set for the given type.
Definition: SPIRVToLLVM.cpp:97
static unsigned getLLVMTypeBitWidth(Type type)
Returns the bit width of LLVMType integer or vector.
Definition: SPIRVToLLVM.cpp:79
#define DISPATCH(functionControl, llvmAttr)
static std::optional< uint64_t > getIntegerOrVectorElementWidth(Type type)
Returns the width of an integer or of the element type of an integer vector, if applicable.
Definition: SPIRVToLLVM.cpp:56
static Type convertStructTypeWithOffset(spirv::StructType type, const TypeConverter &converter)
Converts SPIR-V struct with a regular (according to VulkanLayoutUtils) offset to LLVM struct.
static Type convertStructType(spirv::StructType type, const TypeConverter &converter)
Converts SPIR-V struct to LLVM struct.
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:31
OpListType::iterator iterator
Definition: Block.h:138
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:243
OpListType & getOperations()
Definition: Block.h:135
BlockArgListType getArguments()
Definition: Block.h:85
iterator begin()
Definition: Block.h:141
This class is a general helper class for creating context-global objects like types, attributes, and affine expressions.
Definition: Builders.h:50
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:240
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:203
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:268
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:294
IntegerType getI32Type()
Definition: Builders.cpp:107
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:111
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition: Builders.h:99
MLIRContext * getContext() const
Definition: Builders.h:55
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void eraseBlock(Block *block) override
PatternRewriter hook for erase all operations in a block.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
static LLVMStructType getLiteral(MLIRContext *context, ArrayRef< Type > types, bool isPacked=false)
Gets or creates a literal struct with the given body in the provided context.
Definition: LLVMTypes.cpp:452
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:356
This class helps build Operations.
Definition: Builders.h:215
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:453
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:439
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:406
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:444
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:470
Block * getBlock() const
Returns the current block of the builder.
Definition: Builders.h:456
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
Definition: Builders.h:420
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:450
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:435
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
AttrClass getAttrOfType(StringAttr name)
Definition: Operation.h:545
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:507
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
Definition: Operation.cpp:268
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition: Operation.h:248
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
Attribute removeAttr(StringAttr name)
Remove the attribute with the specified name if it exists.
Definition: Operation.h:595
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
Definition: PatternMatch.h:785
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
MLIRContext * getContext() const
Definition: PatternMatch.h:823
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, StringAttr newSymbol, Operation *from)
Attempt to replace all uses of the given symbol 'oldSymbol' with the provided symbol 'newSymbol' that...
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
static void setSymbolName(Operation *symbol, StringAttr name)
Sets the name of the given symbol operation.
This class provides all of the information necessary to convert a type signature.
Type conversion class.
void addConversion(FnT &&callback)
Register a conversion function.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
Definition: Types.cpp:87
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition: Types.cpp:99
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:127
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:133
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
static spirv::StructType decorateType(spirv::StructType structType)
Returns a new StructType with layout decoration.
Definition: LayoutUtils.cpp:21
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Builder from ArrayRef<T>.
Type getElementType() const
Definition: SPIRVTypes.cpp:66
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:68
unsigned getNumElements() const
Definition: SPIRVTypes.cpp:64
StorageClass getStorageClass() const
Definition: SPIRVTypes.cpp:412
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:473
SPIR-V struct type.
Definition: SPIRVTypes.h:293
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
TypeRange getElementTypes() const
bool isCompatibleVectorType(Type type)
Returns true if the given type is a vector type compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:874
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:856
SmallVector< IntT > convertArrayToIndices(ArrayRef< Attribute > attrs)
Convert an array of integer attributes to a vector of integers that can be used as indices in LLVM op...
Definition: LLVMDialect.h:221
Type getVectorElementType(Type type)
Returns the element type of any vector type compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:890
Include the generated interface declarations.
unsigned storageClassToAddressSpace(spirv::ClientAPI clientAPI, spirv::StorageClass storageClass)
void populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter, spirv::ClientAPI clientAPIForAddressSpaceMapping=spirv::ClientAPI::Unknown)
Populates type conversions with additional SPIR-V types.
void populateSPIRVToLLVMFunctionConversionPatterns(const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the given list with patterns for function conversion from SPIR-V to LLVM.
void populateSPIRVToLLVMConversionPatterns(const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, spirv::ClientAPI clientAPIForAddressSpaceMapping=spirv::ClientAPI::Unknown)
Populates the given list with patterns that convert from SPIR-V to LLVM.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void encodeBindAttribute(ModuleOp module)
Encodes global variable's descriptor set and binding into its name if they both exist.
void populateSPIRVToLLVMModuleConversionPatterns(const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the given patterns for module conversion from SPIR-V to LLVM.