MLIR  14.0.0git
SPIRVToLLVM.cpp
Go to the documentation of this file.
1 //===- SPIRVToLLVM.cpp - SPIR-V to LLVM Patterns --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert SPIR-V dialect to LLVM dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
21 #include "mlir/IR/BuiltinOps.h"
22 #include "mlir/IR/PatternMatch.h"
25 #include "llvm/Support/Debug.h"
26 #include "llvm/Support/FormatVariadic.h"
27 
28 #define DEBUG_TYPE "spirv-to-llvm-pattern"
29 
30 using namespace mlir;
31 
32 //===----------------------------------------------------------------------===//
33 // Utility functions
34 //===----------------------------------------------------------------------===//
35 
36 /// Returns true if the given type is a signed integer or vector type.
37 static bool isSignedIntegerOrVector(Type type) {
38  if (type.isSignedInteger())
39  return true;
40  if (auto vecType = type.dyn_cast<VectorType>())
41  return vecType.getElementType().isSignedInteger();
42  return false;
43 }
44 
45 /// Returns true if the given type is an unsigned integer or vector type
46 static bool isUnsignedIntegerOrVector(Type type) {
47  if (type.isUnsignedInteger())
48  return true;
49  if (auto vecType = type.dyn_cast<VectorType>())
50  return vecType.getElementType().isUnsignedInteger();
51  return false;
52 }
53 
54 /// Returns the bit width of integer, float or vector of float or integer values
55 static unsigned getBitWidth(Type type) {
56  assert((type.isIntOrFloat() || type.isa<VectorType>()) &&
57  "bitwidth is not supported for this type");
58  if (type.isIntOrFloat())
59  return type.getIntOrFloatBitWidth();
60  auto vecType = type.dyn_cast<VectorType>();
61  auto elementType = vecType.getElementType();
62  assert(elementType.isIntOrFloat() &&
63  "only integers and floats have a bitwidth");
64  return elementType.getIntOrFloatBitWidth();
65 }
66 
67 /// Returns the width of the LLVM-dialect integer type `type`, or of its
/// integer element type when `type` is a vector. The selected type is cast to
/// `IntegerType`, so non-integer scalars and elements are not supported.
/// NOTE(review): source line 69 is missing from this listing. That line
/// chooses between the vector element type and `type` itself; the vector case
/// above comes from the original comment — confirm against the full source.
68 static unsigned getLLVMTypeBitWidth(Type type) {
70  : type)
71  .cast<IntegerType>()
72  .getWidth();
73 }
74 
75 /// Creates `IntegerAttribute` with all bits set for given type
76 static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder) {
77  if (auto vecType = type.dyn_cast<VectorType>()) {
78  auto integerType = vecType.getElementType().cast<IntegerType>();
79  return builder.getIntegerAttr(integerType, -1);
80  }
81  auto integerType = type.cast<IntegerType>();
82  return builder.getIntegerAttr(integerType, -1);
83 }
84 
85 /// Creates `llvm.mlir.constant` with all bits set for the given type.
86 static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType,
87  PatternRewriter &rewriter) {
88  if (srcType.isa<VectorType>()) {
89  return rewriter.create<LLVM::ConstantOp>(
90  loc, dstType,
91  SplatElementsAttr::get(srcType.cast<ShapedType>(),
92  minusOneIntegerAttribute(srcType, rewriter)));
93  }
94  return rewriter.create<LLVM::ConstantOp>(
95  loc, dstType, minusOneIntegerAttribute(srcType, rewriter));
96 }
97 
98 /// Creates `llvm.mlir.constant` with a floating-point scalar or vector value.
99 static Value createFPConstant(Location loc, Type srcType, Type dstType,
100  PatternRewriter &rewriter, double value) {
101  if (auto vecType = srcType.dyn_cast<VectorType>()) {
102  auto floatType = vecType.getElementType().cast<FloatType>();
103  return rewriter.create<LLVM::ConstantOp>(
104  loc, dstType,
105  SplatElementsAttr::get(vecType,
106  rewriter.getFloatAttr(floatType, value)));
107  }
108  auto floatType = srcType.cast<FloatType>();
109  return rewriter.create<LLVM::ConstantOp>(
110  loc, dstType, rewriter.getFloatAttr(floatType, value));
111 }
112 
113 /// Utility function for bitfield ops:
114 /// - `BitFieldInsert`
115 /// - `BitFieldSExtract`
116 /// - `BitFieldUExtract`
117 /// Zero-extends (`llvm.zext`) or truncates (`llvm.trunc`) `value` so that its
118 /// bit width matches that of `llvmType`. If the widths are already equal,
/// `value` is returned unchanged. The width of `value` is measured with
/// `getLLVMTypeBitWidth` when its type is already LLVM-compatible, and with
/// `getBitWidth` otherwise.
/// NOTE(review): the first signature line (source line 119) is missing from
/// this listing. The function is called at source line 182 as
/// `optionallyTruncateOrExtend(loc, broadcasted, dstType, rewriter)`.
120  Type llvmType,
121  PatternRewriter &rewriter) {
122  auto srcType = value.getType();
123  unsigned targetBitWidth = getLLVMTypeBitWidth(llvmType);
124  unsigned valueBitWidth = LLVM::isCompatibleType(srcType)
125  ? getLLVMTypeBitWidth(srcType)
126  : getBitWidth(srcType);
127 
  // Narrower values are zero-extended, never sign-extended.
128  if (valueBitWidth < targetBitWidth)
129  return rewriter.create<LLVM::ZExtOp>(loc, llvmType, value);
130  // If the bit widths of `Count` and `Offset` are greater than the bit width
131  // of the target type, they are truncated. Truncation is safe since `Count`
132  // and `Offset` must be no more than 64 for op behaviour to be defined. Hence,
133  // both values can be expressed in 8 bits.
134  if (valueBitWidth > targetBitWidth)
135  return rewriter.create<LLVM::TruncOp>(loc, llvmType, value);
136  return value;
137 }
138 
139 /// Broadcasts the value to vector with `numElements` number of elements.
140 static Value broadcast(Location loc, Value toBroadcast, unsigned numElements,
141  LLVMTypeConverter &typeConverter,
142  ConversionPatternRewriter &rewriter) {
143  auto vectorType = VectorType::get(numElements, toBroadcast.getType());
144  auto llvmVectorType = typeConverter.convertType(vectorType);
145  auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32));
146  Value broadcasted = rewriter.create<LLVM::UndefOp>(loc, llvmVectorType);
147  for (unsigned i = 0; i < numElements; ++i) {
148  auto index = rewriter.create<LLVM::ConstantOp>(
149  loc, llvmI32Type, rewriter.getI32IntegerAttr(i));
150  broadcasted = rewriter.create<LLVM::InsertElementOp>(
151  loc, llvmVectorType, broadcasted, toBroadcast, index);
152  }
153  return broadcasted;
154 }
155 
156 /// If `srcType` is a vector type, broadcasts `value` into a vector with the
/// same number of elements (via `broadcast`). Otherwise, for a scalar
/// `srcType`, returns `value` unchanged.
/// NOTE(review): the first signature line (source line 157) is missing from
/// this listing. The function is called at source line 181 as
/// `optionallyBroadcast(loc, value, srcType, converter, rewriter)`.
158  LLVMTypeConverter &typeConverter,
159  ConversionPatternRewriter &rewriter) {
160  if (auto vectorType = srcType.dyn_cast<VectorType>()) {
161  unsigned numElements = vectorType.getNumElements();
162  return broadcast(loc, value, numElements, typeConverter, rewriter);
163  }
164  return value;
165 }
166 
167 /// Utility function for bitfield ops: `BitFieldInsert`, `BitFieldSExtract` and
168 /// `BitFieldUExtract`.
169 /// Broadcast `Offset` and `Count` to match the type of `Base`. If `Base` is of
170 /// a vector type, construct a vector that has:
171 /// - same number of elements as `Base`
172 /// - each element has the type that is the same as the type of `Offset` or
173 /// `Count`
174 /// - each element has the same value as `Offset` or `Count`
175 /// Then zero-extend or truncate the (possibly broadcast) value if its bit
176 /// width differs from that of `dstType`.
/// NOTE(review): the first signature line (source line 177) is missing from
/// this listing. The function is called as
/// `processCountOrOffset(loc, op.offset(), srcType, dstType, typeConverter, rewriter)`.
178  Type dstType, LLVMTypeConverter &converter,
179  ConversionPatternRewriter &rewriter) {
180  Value broadcasted =
181  optionallyBroadcast(loc, value, srcType, converter, rewriter);
182  return optionallyTruncateOrExtend(loc, broadcasted, dstType, rewriter);
183 }
184 
185 /// Converts a SPIR-V struct with explicit member offsets into a non-packed LLVM
186 /// literal struct. This is done only when the struct is unchanged by
/// `VulkanLayoutUtils::decorateType`, i.e. its offsets match that layout.
/// Otherwise returns `llvm::None`.
/// NOTE(review): source line 188 (the function name and first parameter) is
/// missing from this listing. Converted member types are not null-checked
/// here.
187 static Optional<Type>
189  LLVMTypeConverter &converter) {
190  if (type != VulkanLayoutUtils::decorateType(type))
191  return llvm::None;
192 
193  auto elementsVector = llvm::to_vector<8>(
194  llvm::map_range(type.getElementTypes(), [&](Type elementType) {
195  return converter.convertType(elementType);
196  }));
197  return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
198  /*isPacked=*/false);
199 }
200 
201 /// Converts a SPIR-V struct with no member offsets into a packed LLVM literal
/// struct whose members are the converted SPIR-V member types, in order.
/// NOTE(review): the first signature line (source line 202) is missing from
/// this listing.
203  LLVMTypeConverter &converter) {
204  auto elementsVector = llvm::to_vector<8>(
205  llvm::map_range(type.getElementTypes(), [&](Type elementType) {
206  return converter.convertType(elementType);
207  }));
208  return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
209  /*isPacked=*/true);
210 }
211 
212 /// Creates an `llvm.mlir.constant` of type i32 holding `value`.
/// NOTE(review): the first signature line (source line 213) is missing from
/// this listing. The function is called as
/// `createI32ConstantOf(loc, rewriter, value)`.
214  unsigned value) {
215  return rewriter.create<LLVM::ConstantOp>(
216  loc, IntegerType::get(rewriter.getContext(), 32),
217  rewriter.getIntegerAttr(rewriter.getI32Type(), value));
218 }
219 
220 /// Utility for `spv.Load` and `spv.Store` conversion.
/// Replaces a `spv.Load` with `llvm.load`, or a `spv.Store` with `llvm.store`.
/// The given alignment and volatile/nontemporal flags are forwarded.
/// Converted operands are read through the op's adaptor built from
/// `operands`. Any op that is not a `spv.Load` is treated as a `spv.Store`
/// (`cast` asserts this). Returns failure if the load's result type cannot be
/// converted.
/// NOTE(review): the first signature line (source line 221) is missing from
/// this listing.
222  ConversionPatternRewriter &rewriter,
223  LLVMTypeConverter &typeConverter,
224  unsigned alignment, bool isVolatile,
225  bool isNonTemporal) {
226  if (auto loadOp = dyn_cast<spirv::LoadOp>(op)) {
227  auto dstType = typeConverter.convertType(loadOp.getType());
228  if (!dstType)
229  return failure();
230  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
231  loadOp, dstType, spirv::LoadOpAdaptor(operands).ptr(), alignment,
232  isVolatile, isNonTemporal);
233  return success();
234  }
235  auto storeOp = cast<spirv::StoreOp>(op);
236  spirv::StoreOpAdaptor adaptor(operands);
237  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.value(),
238  adaptor.ptr(), alignment,
239  isVolatile, isNonTemporal);
240  return success();
241 }
242 
243 //===----------------------------------------------------------------------===//
244 // Type conversion
245 //===----------------------------------------------------------------------===//
246 
247 /// Converts a SPIR-V array type to an LLVM array with the same element count.
248 /// A stride of 0 (no explicit stride) is always accepted. A non-zero stride is
249 /// accepted only when it equals the element type's size in bytes; otherwise
/// `llvm::None` is returned. Ops that manipulate array types must respect
/// this mapping.
/// NOTE(review): the first signature line (source line 250) is missing from
/// this listing.
251  TypeConverter &converter) {
252  unsigned stride = type.getArrayStride();
253  Type elementType = type.getElementType();
254  auto sizeInBytes = elementType.cast<spirv::SPIRVType>().getSizeInBytes();
255  if (stride != 0 &&
256  !(sizeInBytes.hasValue() && sizeInBytes.getValue() == stride))
257  return llvm::None;
258 
259  auto llvmElementType = converter.convertType(elementType);
260  unsigned numElements = type.getNumElements();
261  return LLVM::LLVMArrayType::get(llvmElementType, numElements);
262 }
263 
264 /// Converts a SPIR-V pointer type to an LLVM pointer to the converted pointee
265 /// type. The pointer's storage class is not modelled at the moment.
/// NOTE(review): the first signature line (source line 266) is missing from
/// this listing.
267  TypeConverter &converter) {
268  auto pointeeType = converter.convertType(type.getPointeeType());
269  return LLVM::LLVMPointerType::get(pointeeType);
270 }
271 
272 /// Converts SPIR-V runtime array to LLVM array. Since LLVM allows indexing over
273 /// the bounds, the runtime array is converted to a 0-sized LLVM array. Array
274 /// strides are not modelled: a runtime array with a non-zero stride is
/// rejected (`llvm::None`).
/// NOTE(review): the first signature line (source line 275) is missing from
/// this listing.
276  TypeConverter &converter) {
277  if (type.getArrayStride() != 0)
278  return llvm::None;
279  auto elementType = converter.convertType(type.getElementType());
280  return LLVM::LLVMArrayType::get(elementType, 0);
281 }
282 
283 /// Converts SPIR-V struct to LLVM struct. Structs with any member decorations
284 /// are rejected (`llvm::None`). A struct with offsets is converted by
/// `convertStructTypeWithOffset`, which accepts only the natural layout. A
/// struct without offsets becomes a packed LLVM struct.
/// NOTE(review): source lines 285 (signature start) and 287 (declaration of
/// `memberDecorations`) are missing from this listing.
286  LLVMTypeConverter &converter) {
288  type.getMemberDecorations(memberDecorations);
289  if (!memberDecorations.empty())
290  return llvm::None;
291  if (type.hasOffset())
292  return convertStructTypeWithOffset(type, converter);
293  return convertStructTypePacked(type, converter);
294 }
295 
296 //===----------------------------------------------------------------------===//
297 // Operation conversion
298 //===----------------------------------------------------------------------===//
299 
300 namespace {
301 
/// Converts `spv.AccessChain` to `llvm.getelementptr` over the converted base
/// pointer. The GEP gets an extra leading zero index, of the same type as the
/// first SPIR-V index, to step through the pointer itself.
302 class AccessChainPattern : public SPIRVToLLVMConversion<spirv::AccessChainOp> {
303 public:
305 
307  matchAndRewrite(spirv::AccessChainOp op, OpAdaptor adaptor,
308  ConversionPatternRewriter &rewriter) const override {
309  auto dstType = typeConverter.convertType(op.component_ptr().getType());
310  if (!dstType)
311  return failure();
312  // To use GEP we need to add a first 0 index to go through the pointer.
313  auto indices = llvm::to_vector<4>(adaptor.indices());
  // NOTE(review): `front()` assumes the access chain has at least one index.
314  Type indexType = op.indices().front().getType();
315  auto llvmIndexType = typeConverter.convertType(indexType);
316  if (!llvmIndexType)
317  return failure();
318  Value zero = rewriter.create<LLVM::ConstantOp>(
319  op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0));
320  indices.insert(indices.begin(), zero);
321  rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, dstType, adaptor.base_ptr(),
322  indices);
323  return success();
324  }
325 };
326 
/// Converts `spv.mlir.addressof` to `llvm.mlir.addressof`. The result has the
/// converted pointer type and refers to the same variable symbol.
327 class AddressOfPattern : public SPIRVToLLVMConversion<spirv::AddressOfOp> {
328 public:
330 
332  matchAndRewrite(spirv::AddressOfOp op, OpAdaptor adaptor,
333  ConversionPatternRewriter &rewriter) const override {
334  auto dstType = typeConverter.convertType(op.pointer().getType());
335  if (!dstType)
336  return failure();
337  rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, dstType, op.variable());
338  return success();
339  }
340 };
341 
/// Lowers `spv.BitFieldInsert` to shift/xor/and/or arithmetic in LLVM dialect.
/// The result is computed as (Base & mask) | (Insert << Offset), where mask
/// clears bits [Offset, Offset + Count - 1].
/// NOTE(review): `Offset`, `Count`, `Base` and `Insert` are taken from `op`
/// rather than from `adaptor` (the converted operands) — confirm intended.
342 class BitFieldInsertPattern
343  : public SPIRVToLLVMConversion<spirv::BitFieldInsertOp> {
344 public:
346 
348  matchAndRewrite(spirv::BitFieldInsertOp op, OpAdaptor adaptor,
349  ConversionPatternRewriter &rewriter) const override {
350  auto srcType = op.getType();
351  auto dstType = typeConverter.convertType(srcType);
352  if (!dstType)
353  return failure();
354  Location loc = op.getLoc();
355 
356  // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
357  Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType,
358  typeConverter, rewriter);
359  Value count = processCountOrOffset(loc, op.count(), srcType, dstType,
360  typeConverter, rewriter);
361 
362  // Create a mask with bits set outside [Offset, Offset + Count - 1].
  // NOTE(review): when `Count` equals the bit width, `minusOne << count`
  // shifts by the full width, which LLVM defines as poison.
363  Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
364  Value maskShiftedByCount =
365  rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count);
366  Value negated = rewriter.create<LLVM::XOrOp>(loc, dstType,
367  maskShiftedByCount, minusOne);
368  Value maskShiftedByCountAndOffset =
369  rewriter.create<LLVM::ShlOp>(loc, dstType, negated, offset);
370  Value mask = rewriter.create<LLVM::XOrOp>(
371  loc, dstType, maskShiftedByCountAndOffset, minusOne);
372 
373  // Extract unchanged bits from the `Base` that are outside of
374  // [Offset, Offset + Count - 1]. Then `or` with shifted `Insert`.
  // NOTE(review): `Insert` is shifted but never masked to its low `Count`
  // bits. The SPIR-V spec uses only bits [0, Count - 1] of `Insert`, so any
  // higher set bits would overwrite `Base` bits above Offset + Count - 1.
  // Consider and-ing the shifted `Insert` with the inverse of `mask` first.
375  Value baseAndMask =
376  rewriter.create<LLVM::AndOp>(loc, dstType, op.base(), mask);
377  Value insertShiftedByOffset =
378  rewriter.create<LLVM::ShlOp>(loc, dstType, op.insert(), offset);
379  rewriter.replaceOpWithNewOp<LLVM::OrOp>(op, dstType, baseAndMask,
380  insertShiftedByOffset);
381  return success();
382  }
383 };
384 
385 /// Converts SPIR-V ConstantOp with scalar or vector type to
/// `llvm.mlir.constant`. Signed and unsigned integer constants, and vectors
/// of them, are re-typed to signless integers of the same width. All other
/// int/float constants keep their original attributes. Constants of other
/// types are rejected.
386 class ConstantScalarAndVectorPattern
387  : public SPIRVToLLVMConversion<spirv::ConstantOp> {
388 public:
390 
392  matchAndRewrite(spirv::ConstantOp constOp, OpAdaptor adaptor,
393  ConversionPatternRewriter &rewriter) const override {
394  auto srcType = constOp.getType();
395  if (!srcType.isa<VectorType>() && !srcType.isIntOrFloat())
396  return failure();
397 
398  auto dstType = typeConverter.convertType(srcType);
399  if (!dstType)
400  return failure();
401 
402  // SPIR-V constant can be a signed/unsigned integer, which has to be
403  // cast to a signless integer when converting to LLVM dialect. Dropping the
404  // signedness may have unexpected behaviour. However, it is better to handle
405  // it case-by-case, given that the purpose of the conversion is not to
406  // cover all possible corner cases.
407  if (isSignedIntegerOrVector(srcType) ||
408  isUnsignedIntegerOrVector(srcType)) {
409  auto signlessType = rewriter.getIntegerType(getBitWidth(srcType));
410 
411  if (srcType.isa<VectorType>()) {
  // Keep every element's bit pattern; only the element type changes.
412  auto dstElementsAttr = constOp.value().cast<DenseIntElementsAttr>();
413  rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
414  constOp, dstType,
415  dstElementsAttr.mapValues(
416  signlessType, [&](const APInt &value) { return value; }));
417  return success();
418  }
419  auto srcAttr = constOp.value().cast<IntegerAttr>();
420  auto dstAttr = rewriter.getIntegerAttr(signlessType, srcAttr.getValue());
421  rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, dstAttr);
422  return success();
423  }
424  rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
425  constOp, dstType, adaptor.getOperands(), constOp->getAttrs());
426  return success();
427  }
428 };
429 
/// Lowers `spv.BitFieldSExtract` by shifting `Base` left so that bit
/// Offset + Count - 1 becomes the most significant bit. An arithmetic right
/// shift then moves the field back to bit 0 and sign-extends it.
/// NOTE(review): if `Count` is 0, the right-shift amount equals the full bit
/// width, which LLVM defines as poison. SPIR-V expects 0 in that case.
430 class BitFieldSExtractPattern
431  : public SPIRVToLLVMConversion<spirv::BitFieldSExtractOp> {
432 public:
434 
436  matchAndRewrite(spirv::BitFieldSExtractOp op, OpAdaptor adaptor,
437  ConversionPatternRewriter &rewriter) const override {
438  auto srcType = op.getType();
439  auto dstType = typeConverter.convertType(srcType);
440  if (!dstType)
441  return failure();
442  Location loc = op.getLoc();
443 
444  // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
445  Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType,
446  typeConverter, rewriter);
447  Value count = processCountOrOffset(loc, op.count(), srcType, dstType,
448  typeConverter, rewriter);
449 
450  // Create a constant that holds the size of the `Base`.
451  IntegerType integerType;
452  if (auto vecType = srcType.dyn_cast<VectorType>())
453  integerType = vecType.getElementType().cast<IntegerType>();
454  else
455  integerType = srcType.cast<IntegerType>();
456 
457  auto baseSize = rewriter.getIntegerAttr(integerType, getBitWidth(srcType));
458  Value size =
459  srcType.isa<VectorType>()
460  ? rewriter.create<LLVM::ConstantOp>(
461  loc, dstType,
462  SplatElementsAttr::get(srcType.cast<ShapedType>(), baseSize))
463  : rewriter.create<LLVM::ConstantOp>(loc, dstType, baseSize);
464 
465  // Shift `Base` left by [sizeof(Base) - (Count + Offset)], so that the bit
466  // at Offset + Count - 1 is the most significant bit now.
467  Value countPlusOffset =
468  rewriter.create<LLVM::AddOp>(loc, dstType, count, offset);
469  Value amountToShiftLeft =
470  rewriter.create<LLVM::SubOp>(loc, dstType, size, countPlusOffset);
471  Value baseShiftedLeft = rewriter.create<LLVM::ShlOp>(
472  loc, dstType, op.base(), amountToShiftLeft);
473 
474  // Shift the result right, filling the bits with the sign bit.
  // The right-shift amount is Offset + (size - Count - Offset) = size - Count.
475  Value amountToShiftRight =
476  rewriter.create<LLVM::AddOp>(loc, dstType, offset, amountToShiftLeft);
477  rewriter.replaceOpWithNewOp<LLVM::AShrOp>(op, dstType, baseShiftedLeft,
478  amountToShiftRight);
479  return success();
480  }
481 };
482 
/// Lowers `spv.BitFieldUExtract` to (Base >> Offset) & mask, using a logical
/// right shift. Here mask has bits [0, Count - 1] set.
/// NOTE(review): when `Count` equals the bit width, `minusOne << count`
/// shifts by the full width, which LLVM defines as poison.
483 class BitFieldUExtractPattern
484  : public SPIRVToLLVMConversion<spirv::BitFieldUExtractOp> {
485 public:
487 
489  matchAndRewrite(spirv::BitFieldUExtractOp op, OpAdaptor adaptor,
490  ConversionPatternRewriter &rewriter) const override {
491  auto srcType = op.getType();
492  auto dstType = typeConverter.convertType(srcType);
493  if (!dstType)
494  return failure();
495  Location loc = op.getLoc();
496 
497  // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
498  Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType,
499  typeConverter, rewriter);
500  Value count = processCountOrOffset(loc, op.count(), srcType, dstType,
501  typeConverter, rewriter);
502 
503  // Create a mask with bits set at [0, Count - 1].
504  Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
505  Value maskShiftedByCount =
506  rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count);
507  Value mask = rewriter.create<LLVM::XOrOp>(loc, dstType, maskShiftedByCount,
508  minusOne);
509 
510  // Shift `Base` by `Offset` and apply the mask on it.
511  Value shiftedBase =
512  rewriter.create<LLVM::LShrOp>(loc, dstType, op.base(), offset);
513  rewriter.replaceOpWithNewOp<LLVM::AndOp>(op, dstType, shiftedBase, mask);
514  return success();
515  }
516 };
517 
/// Converts `spv.Branch` to `llvm.br` to the same target block, passing the
/// converted operands as block arguments.
518 class BranchConversionPattern : public SPIRVToLLVMConversion<spirv::BranchOp> {
519 public:
521 
523  matchAndRewrite(spirv::BranchOp branchOp, OpAdaptor adaptor,
524  ConversionPatternRewriter &rewriter) const override {
525  rewriter.replaceOpWithNewOp<LLVM::BrOp>(branchOp, adaptor.getOperands(),
526  branchOp.getTarget());
527  return success();
528  }
529 };
530 
/// Converts `spv.BranchConditional` to `llvm.cond_br`. Optional branch weights
/// are carried over as a 2-element i32 vector attribute.
/// NOTE(review): the condition and block arguments come from `op`, not from
/// `adaptor` (the converted operands) — confirm this is intended.
531 class BranchConditionalConversionPattern
532  : public SPIRVToLLVMConversion<spirv::BranchConditionalOp> {
533 public:
534  using SPIRVToLLVMConversion<
535  spirv::BranchConditionalOp>::SPIRVToLLVMConversion;
536 
538  matchAndRewrite(spirv::BranchConditionalOp op, OpAdaptor adaptor,
539  ConversionPatternRewriter &rewriter) const override {
540  // If branch weights exist, map them to 32-bit integer vector.
541  ElementsAttr branchWeights = nullptr;
542  if (auto weights = op.branch_weights()) {
543  VectorType weightType = VectorType::get(2, rewriter.getI32Type());
544  branchWeights =
545  DenseElementsAttr::get(weightType, weights.getValue().getValue());
546  }
547 
548  rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
549  op, op.condition(), op.getTrueBlockArguments(),
550  op.getFalseBlockArguments(), branchWeights, op.getTrueBlock(),
551  op.getFalseBlock());
552  return success();
553  }
554 };
555 
556 /// Converts `spv.CompositeExtract` to `llvm.extractvalue` if the container type
557 /// is an aggregate type (struct or array). Otherwise, converts to
558 /// `llvm.extractelement` that operates on vectors.
/// For vectors, only the first index is used, materialized as an i32 constant.
559 class CompositeExtractPattern
560  : public SPIRVToLLVMConversion<spirv::CompositeExtractOp> {
561 public:
563 
565  matchAndRewrite(spirv::CompositeExtractOp op, OpAdaptor adaptor,
566  ConversionPatternRewriter &rewriter) const override {
567  auto dstType = this->typeConverter.convertType(op.getType());
568  if (!dstType)
569  return failure();
570 
571  Type containerType = op.composite().getType();
572  if (containerType.isa<VectorType>()) {
573  Location loc = op.getLoc();
574  IntegerAttr value = op.indices()[0].cast<IntegerAttr>();
575  Value index = createI32ConstantOf(loc, rewriter, value.getInt());
576  rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
577  op, dstType, adaptor.composite(), index);
578  return success();
579  }
580  rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(
581  op, dstType, adaptor.composite(), op.indices());
582  return success();
583  }
584 };
585 
586 /// Converts `spv.CompositeInsert` to `llvm.insertvalue` if the container type
587 /// is an aggregate type (struct or array). Otherwise, converts to
588 /// `llvm.insertelement` that operates on vectors.
/// For vectors, only the first index is used, materialized as an i32 constant.
589 class CompositeInsertPattern
590  : public SPIRVToLLVMConversion<spirv::CompositeInsertOp> {
591 public:
593 
595  matchAndRewrite(spirv::CompositeInsertOp op, OpAdaptor adaptor,
596  ConversionPatternRewriter &rewriter) const override {
597  auto dstType = this->typeConverter.convertType(op.getType());
598  if (!dstType)
599  return failure();
600 
601  Type containerType = op.composite().getType();
602  if (containerType.isa<VectorType>()) {
603  Location loc = op.getLoc();
604  IntegerAttr value = op.indices()[0].cast<IntegerAttr>();
605  Value index = createI32ConstantOf(loc, rewriter, value.getInt());
606  rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
607  op, dstType, adaptor.composite(), adaptor.object(), index);
608  return success();
609  }
610  rewriter.replaceOpWithNewOp<LLVM::InsertValueOp>(
611  op, dstType, adaptor.composite(), adaptor.object(), op.indices());
612  return success();
613  }
614 };
615 
616 /// Converts SPIR-V operations that have straightforward LLVM equivalent
617 /// into LLVM dialect operations.
/// The new `LLVMOp` gets the converted result type, the converted operands
/// and a copy of every attribute of the SPIR-V op.
618 template <typename SPIRVOp, typename LLVMOp>
619 class DirectConversionPattern : public SPIRVToLLVMConversion<SPIRVOp> {
620 public:
622 
624  matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
625  ConversionPatternRewriter &rewriter) const override {
626  auto dstType = this->typeConverter.convertType(operation.getType());
627  if (!dstType)
628  return failure();
629  rewriter.template replaceOpWithNewOp<LLVMOp>(
630  operation, dstType, adaptor.getOperands(), operation->getAttrs());
631  return success();
632  }
633 };
634 
635 /// Converts `spv.ExecutionMode` into a global struct constant that holds
636 /// execution mode information.
/// The global is an externally-linked constant placed at the start of the
/// enclosing builtin `ModuleOp`. Its struct holds the mode as i32, plus an i32
/// array of the op's extra `values` when there are any.
637 class ExecutionModePattern
638  : public SPIRVToLLVMConversion<spirv::ExecutionModeOp> {
639 public:
641 
643  matchAndRewrite(spirv::ExecutionModeOp op, OpAdaptor adaptor,
644  ConversionPatternRewriter &rewriter) const override {
645  // First, create the global struct's name that would be associated with
646  // this entry point's execution mode. We set it to be:
647  // __spv__{SPIR-V module name}_{function name}_execution_mode_info_{mode}
  // If the module is unnamed, the module-name part is omitted, giving
  // __spv__{function name}_execution_mode_info_{mode}.
  // NOTE(review): the name is read from the nearest enclosing builtin
  // `ModuleOp`; presumably it carries the SPIR-V module's name at this point
  // — TODO confirm.
648  ModuleOp module = op->getParentOfType<ModuleOp>();
649  IntegerAttr executionModeAttr = op.execution_modeAttr();
650  std::string moduleName;
651  if (module.getName().hasValue())
652  moduleName = "_" + module.getName().getValue().str();
653  else
654  moduleName = "";
655  std::string executionModeInfoName =
656  llvm::formatv("__spv_{0}_{1}_execution_mode_info_{2}", moduleName,
657  op.fn().str(), executionModeAttr.getValue());
658 
659  MLIRContext *context = rewriter.getContext();
660  OpBuilder::InsertionGuard guard(rewriter);
661  rewriter.setInsertionPointToStart(module.getBody());
662 
663  // Create a struct type, corresponding to the C struct below.
664  // struct {
665  // int32_t executionMode;
666  // int32_t values[]; // optional values
667  // };
668  auto llvmI32Type = IntegerType::get(context, 32);
669  SmallVector<Type, 2> fields;
670  fields.push_back(llvmI32Type);
671  ArrayAttr values = op.values();
672  if (!values.empty()) {
673  auto arrayType = LLVM::LLVMArrayType::get(llvmI32Type, values.size());
674  fields.push_back(arrayType);
675  }
676  auto structType = LLVM::LLVMStructType::getLiteral(context, fields);
677 
678  // Create `llvm.mlir.global` with initializer region containing one block.
679  auto global = rewriter.create<LLVM::GlobalOp>(
680  UnknownLoc::get(context), structType, /*isConstant=*/true,
681  LLVM::Linkage::External, executionModeInfoName, Attribute(),
682  /*alignment=*/0);
683  Location loc = global.getLoc();
684  Region &region = global.getInitializerRegion();
685  Block *block = rewriter.createBlock(&region);
686 
687  // Initialize the struct and set the execution mode value.
688  rewriter.setInsertionPoint(block, block->begin());
689  Value structValue = rewriter.create<LLVM::UndefOp>(loc, structType);
690  Value executionMode =
691  rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, executionModeAttr);
692  structValue = rewriter.create<LLVM::InsertValueOp>(
693  loc, structType, structValue, executionMode,
694  ArrayAttr::get(context,
695  {rewriter.getIntegerAttr(rewriter.getI32Type(), 0)}));
696 
697  // Insert extra operands if they exist into execution mode info struct.
  // Each value goes to position [1, i], i.e. element i of the array field.
698  for (unsigned i = 0, e = values.size(); i < e; ++i) {
699  auto attr = values.getValue()[i];
700  Value entry = rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, attr);
701  structValue = rewriter.create<LLVM::InsertValueOp>(
702  loc, structType, structValue, entry,
703  ArrayAttr::get(context,
704  {rewriter.getIntegerAttr(rewriter.getI32Type(), 1),
705  rewriter.getIntegerAttr(rewriter.getI32Type(), i)}));
706  }
707  rewriter.create<LLVM::ReturnOp>(loc, ArrayRef<Value>({structValue}));
708  rewriter.eraseOp(op);
709  return success();
710  }
711 };
712 
713 /// Converts `spv.GlobalVariable` to `llvm.mlir.global`. Note that SPIR-V global
714 /// returns a pointer, whereas in LLVM dialect the global holds an actual value.
715 /// This difference is handled by `spv.mlir.addressof` and
716 /// `llvm.mlir.addressof` ops that both return a pointer.
717 class GlobalVariablePattern
718  : public SPIRVToLLVMConversion<spirv::GlobalVariableOp> {
719 public:
721 
723  matchAndRewrite(spirv::GlobalVariableOp op, OpAdaptor adaptor,
724  ConversionPatternRewriter &rewriter) const override {
725  // Currently, there is no support of initialization with a constant value in
726  // SPIR-V dialect. Specialization constants are not considered as well.
727  if (op.initializer())
728  return failure();
729 
730  auto srcType = op.type().cast<spirv::PointerType>();
731  auto dstType = typeConverter.convertType(srcType.getPointeeType());
732  if (!dstType)
733  return failure();
734 
735  // Only Input, Private, Output, StorageBuffer (required by the SPIR-V
736  // runner) and UniformConstant storage classes are converted; all others
737  // are rejected. This is okay because multiple invocations are not supported yet.
738  auto storageClass = srcType.getStorageClass();
739  switch (storageClass) {
740  case spirv::StorageClass::Input:
741  case spirv::StorageClass::Private:
742  case spirv::StorageClass::Output:
743  case spirv::StorageClass::StorageBuffer:
744  case spirv::StorageClass::UniformConstant:
745  break;
746  default:
747  return failure();
748  }
749 
750  // LLVM dialect spec: "If the global value is a constant, storing into it is
751  // not allowed.". This corresponds to SPIR-V 'Input' and 'UniformConstant'
752  // storage class that is read-only.
753  bool isConstant = (storageClass == spirv::StorageClass::Input) ||
754  (storageClass == spirv::StorageClass::UniformConstant);
755  // SPIR-V spec: "By default, functions and global variables are private to a
756  // module and cannot be accessed by other modules. However, a module may be
757  // written to export or import functions and global (module scope)
758  // variables.". Therefore, map 'Private' storage class to private linkage,
759  // and every other accepted storage class to external linkage.
760  auto linkage = storageClass == spirv::StorageClass::Private
761  ? LLVM::Linkage::Private
762  : LLVM::Linkage::External;
763  auto newGlobalOp = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
764  op, dstType, isConstant, linkage, op.sym_name(), Attribute(),
765  /*alignment=*/0);
766 
767  // Attach location attribute if applicable
768  if (op.locationAttr())
769  newGlobalOp->setAttr(op.locationAttrName(), op.locationAttr());
770 
771  return success();
772  }
773 };
774 
775 /// Converts SPIR-V cast ops that do not have straightforward LLVM
776 /// equivalent in LLVM dialect.
/// The operand and result bit widths are compared. A widening cast becomes
/// `LLVMExtOp` and a narrowing cast becomes `LLVMTruncOp`. If the widths are
/// equal the pattern fails.
777 template <typename SPIRVOp, typename LLVMExtOp, typename LLVMTruncOp>
778 class IndirectCastPattern : public SPIRVToLLVMConversion<SPIRVOp> {
779 public:
781 
783  matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
784  ConversionPatternRewriter &rewriter) const override {
785 
786  Type fromType = operation.operand().getType();
787  Type toType = operation.getType();
788 
789  auto dstType = this->typeConverter.convertType(toType);
790  if (!dstType)
791  return failure();
792 
793  if (getBitWidth(fromType) < getBitWidth(toType)) {
794  rewriter.template replaceOpWithNewOp<LLVMExtOp>(operation, dstType,
795  adaptor.getOperands());
796  return success();
797  }
798  if (getBitWidth(fromType) > getBitWidth(toType)) {
799  rewriter.template replaceOpWithNewOp<LLVMTruncOp>(operation, dstType,
800  adaptor.getOperands());
801  return success();
802  }
803  return failure();
804  }
805 };
806 
/// Converts `spv.FunctionCall` to `llvm.call`, forwarding the converted
/// operands and all attributes. A call with no results gets no result type.
/// Otherwise the single result type is converted.
807 class FunctionCallPattern
808  : public SPIRVToLLVMConversion<spirv::FunctionCallOp> {
809 public:
811 
813  matchAndRewrite(spirv::FunctionCallOp callOp, OpAdaptor adaptor,
814  ConversionPatternRewriter &rewriter) const override {
815  if (callOp.getNumResults() == 0) {
816  rewriter.replaceOpWithNewOp<LLVM::CallOp>(
817  callOp, llvm::None, adaptor.getOperands(), callOp->getAttrs());
818  return success();
819  }
820 
821  // Function returns a single result.
  // NOTE(review): unlike the other patterns, a failed type conversion
  // (null `dstType`) is not checked here.
822  auto dstType = typeConverter.convertType(callOp.getType(0));
823  rewriter.replaceOpWithNewOp<LLVM::CallOp>(
824  callOp, dstType, adaptor.getOperands(), callOp->getAttrs());
825  return success();
826  }
827 };
828 
829 /// Converts SPIR-V floating-point comparisons to llvm.fcmp "predicate"
/// NOTE(review): operands are taken from `operation`, not from `adaptor`
/// (the converted operands) — confirm intended.
830 template <typename SPIRVOp, LLVM::FCmpPredicate predicate>
831 class FComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
832 public:
834 
836  matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
837  ConversionPatternRewriter &rewriter) const override {
838 
839  auto dstType = this->typeConverter.convertType(operation.getType());
840  if (!dstType)
841  return failure();
842 
843  rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>(
844  operation, dstType, predicate, operation.operand1(),
845  operation.operand2());
846  return success();
847  }
848 };
849 
850 /// Converts SPIR-V integer comparisons to llvm.icmp "predicate"
851 template <typename SPIRVOp, LLVM::ICmpPredicate predicate>
852 class IComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
853 public:
855 
857  matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
858  ConversionPatternRewriter &rewriter) const override {
859 
860  auto dstType = this->typeConverter.convertType(operation.getType());
861  if (!dstType)
862  return failure();
863 
864  rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>(
865  operation, dstType, predicate, operation.operand1(),
866  operation.operand2());
867  return success();
868  }
869 };
870 
871 class InverseSqrtPattern
872  : public SPIRVToLLVMConversion<spirv::GLSLInverseSqrtOp> {
873 public:
875 
877  matchAndRewrite(spirv::GLSLInverseSqrtOp op, OpAdaptor adaptor,
878  ConversionPatternRewriter &rewriter) const override {
879  auto srcType = op.getType();
880  auto dstType = typeConverter.convertType(srcType);
881  if (!dstType)
882  return failure();
883 
884  Location loc = op.getLoc();
885  Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
886  Value sqrt = rewriter.create<LLVM::SqrtOp>(loc, dstType, op.operand());
887  rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, dstType, one, sqrt);
888  return success();
889  }
890 };
891 
892 /// Converts `spv.Load` and `spv.Store` to LLVM dialect.
893 template <typename SPIRVOp>
894 class LoadStorePattern : public SPIRVToLLVMConversion<SPIRVOp> {
895 public:
897 
899  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
900  ConversionPatternRewriter &rewriter) const override {
901  if (!op.memory_access().hasValue()) {
902  return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter,
903  this->typeConverter, /*alignment=*/0,
904  /*isVolatile=*/false,
905  /*isNonTemporal=*/false);
906  }
907  auto memoryAccess = op.memory_access().getValue();
908  switch (memoryAccess) {
909  case spirv::MemoryAccess::Aligned:
910  case spirv::MemoryAccess::None:
911  case spirv::MemoryAccess::Nontemporal:
912  case spirv::MemoryAccess::Volatile: {
913  unsigned alignment =
914  memoryAccess == spirv::MemoryAccess::Aligned ? *op.alignment() : 0;
915  bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal;
916  bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile;
917  return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter,
918  this->typeConverter, alignment, isVolatile,
919  isNonTemporal);
920  }
921  default:
922  // There is no support of other memory access attributes.
923  return failure();
924  }
925  }
926 };
927 
928 /// Converts `spv.Not` and `spv.LogicalNot` into LLVM dialect.
929 template <typename SPIRVOp>
930 class NotPattern : public SPIRVToLLVMConversion<SPIRVOp> {
931 public:
933 
935  matchAndRewrite(SPIRVOp notOp, typename SPIRVOp::Adaptor adaptor,
936  ConversionPatternRewriter &rewriter) const override {
937  auto srcType = notOp.getType();
938  auto dstType = this->typeConverter.convertType(srcType);
939  if (!dstType)
940  return failure();
941 
942  Location loc = notOp.getLoc();
943  IntegerAttr minusOne = minusOneIntegerAttribute(srcType, rewriter);
944  auto mask = srcType.template isa<VectorType>()
945  ? rewriter.create<LLVM::ConstantOp>(
946  loc, dstType,
948  srcType.template cast<VectorType>(), minusOne))
949  : rewriter.create<LLVM::ConstantOp>(loc, dstType, minusOne);
950  rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(notOp, dstType,
951  notOp.operand(), mask);
952  return success();
953  }
954 };
955 
956 /// A template pattern that erases the given `SPIRVOp`.
957 template <typename SPIRVOp>
958 class ErasePattern : public SPIRVToLLVMConversion<SPIRVOp> {
959 public:
961 
963  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
964  ConversionPatternRewriter &rewriter) const override {
965  rewriter.eraseOp(op);
966  return success();
967  }
968 };
969 
970 class ReturnPattern : public SPIRVToLLVMConversion<spirv::ReturnOp> {
971 public:
973 
975  matchAndRewrite(spirv::ReturnOp returnOp, OpAdaptor adaptor,
976  ConversionPatternRewriter &rewriter) const override {
977  rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp, ArrayRef<Type>(),
978  ArrayRef<Value>());
979  return success();
980  }
981 };
982 
983 class ReturnValuePattern : public SPIRVToLLVMConversion<spirv::ReturnValueOp> {
984 public:
986 
988  matchAndRewrite(spirv::ReturnValueOp returnValueOp, OpAdaptor adaptor,
989  ConversionPatternRewriter &rewriter) const override {
990  rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnValueOp, ArrayRef<Type>(),
991  adaptor.getOperands());
992  return success();
993  }
994 };
995 
996 /// Converts `spv.mlir.loop` to LLVM dialect. All blocks within selection should
997 /// be reachable for conversion to succeed. The structure of the loop in LLVM
998 /// dialect will be the following:
999 ///
1000 /// +------------------------------------+
1001 /// | <code before spv.mlir.loop> |
1002 /// | llvm.br ^header |
1003 /// +------------------------------------+
1004 /// |
1005 /// +----------------+ |
1006 /// | | |
1007 /// | V V
1008 /// | +------------------------------------+
1009 /// | | ^header: |
1010 /// | | <header code> |
1011 /// | | llvm.cond_br %cond, ^body, ^exit |
1012 /// | +------------------------------------+
1013 /// | |
1014 /// | |----------------------+
1015 /// | | |
1016 /// | V |
1017 /// | +------------------------------------+ |
1018 /// | | ^body: | |
1019 /// | | <body code> | |
1020 /// | | llvm.br ^continue | |
1021 /// | +------------------------------------+ |
1022 /// | | |
1023 /// | V |
1024 /// | +------------------------------------+ |
1025 /// | | ^continue: | |
1026 /// | | <continue code> | |
1027 /// | | llvm.br ^header | |
1028 /// | +------------------------------------+ |
1029 /// | | |
1030 /// +---------------+ +----------------------+
1031 /// |
1032 /// V
1033 /// +------------------------------------+
1034 /// | ^exit: |
1035 /// | llvm.br ^remaining |
1036 /// +------------------------------------+
1037 /// |
1038 /// V
1039 /// +------------------------------------+
1040 /// | ^remaining: |
1041 /// | <code after spv.mlir.loop> |
1042 /// +------------------------------------+
1043 ///
1044 class LoopPattern : public SPIRVToLLVMConversion<spirv::LoopOp> {
1045 public:
1047 
1049  matchAndRewrite(spirv::LoopOp loopOp, OpAdaptor adaptor,
1050  ConversionPatternRewriter &rewriter) const override {
1051  // There is no support of loop control at the moment.
1052  if (loopOp.loop_control() != spirv::LoopControl::None)
1053  return failure();
1054 
1055  Location loc = loopOp.getLoc();
1056 
1057  // Split the current block after `spv.mlir.loop`. The remaining ops will be
1058  // used in `endBlock`.
1059  Block *currentBlock = rewriter.getBlock();
1060  auto position = Block::iterator(loopOp);
1061  Block *endBlock = rewriter.splitBlock(currentBlock, position);
1062 
1063  // Remove entry block and create a branch in the current block going to the
1064  // header block.
1065  Block *entryBlock = loopOp.getEntryBlock();
1066  assert(entryBlock->getOperations().size() == 1);
1067  auto brOp = dyn_cast<spirv::BranchOp>(entryBlock->getOperations().front());
1068  if (!brOp)
1069  return failure();
1070  Block *headerBlock = loopOp.getHeaderBlock();
1071  rewriter.setInsertionPointToEnd(currentBlock);
1072  rewriter.create<LLVM::BrOp>(loc, brOp.getBlockArguments(), headerBlock);
1073  rewriter.eraseBlock(entryBlock);
1074 
1075  // Branch from merge block to end block.
1076  Block *mergeBlock = loopOp.getMergeBlock();
1077  Operation *terminator = mergeBlock->getTerminator();
1078  ValueRange terminatorOperands = terminator->getOperands();
1079  rewriter.setInsertionPointToEnd(mergeBlock);
1080  rewriter.create<LLVM::BrOp>(loc, terminatorOperands, endBlock);
1081 
1082  rewriter.inlineRegionBefore(loopOp.body(), endBlock);
1083  rewriter.replaceOp(loopOp, endBlock->getArguments());
1084  return success();
1085  }
1086 };
1087 
1088 /// Converts `spv.mlir.selection` with `spv.BranchConditional` in its header
1089 /// block. All blocks within selection should be reachable for conversion to
1090 /// succeed.
1091 class SelectionPattern : public SPIRVToLLVMConversion<spirv::SelectionOp> {
1092 public:
1094 
1096  matchAndRewrite(spirv::SelectionOp op, OpAdaptor adaptor,
1097  ConversionPatternRewriter &rewriter) const override {
1098  // There is no support for `Flatten` or `DontFlatten` selection control at
1099  // the moment. This are just compiler hints and can be performed during the
1100  // optimization passes.
1101  if (op.selection_control() != spirv::SelectionControl::None)
1102  return failure();
1103 
1104  // `spv.mlir.selection` should have at least two blocks: one selection
1105  // header block and one merge block. If no blocks are present, or control
1106  // flow branches straight to merge block (two blocks are present), the op is
1107  // redundant and it is erased.
1108  if (op.body().getBlocks().size() <= 2) {
1109  rewriter.eraseOp(op);
1110  return success();
1111  }
1112 
1113  Location loc = op.getLoc();
1114 
1115  // Split the current block after `spv.mlir.selection`. The remaining ops
1116  // will be used in `continueBlock`.
1117  auto *currentBlock = rewriter.getInsertionBlock();
1118  rewriter.setInsertionPointAfter(op);
1119  auto position = rewriter.getInsertionPoint();
1120  auto *continueBlock = rewriter.splitBlock(currentBlock, position);
1121 
1122  // Extract conditional branch information from the header block. By SPIR-V
1123  // dialect spec, it should contain `spv.BranchConditional` or `spv.Switch`
1124  // op. Note that `spv.Switch op` is not supported at the moment in the
1125  // SPIR-V dialect. Remove this block when finished.
1126  auto *headerBlock = op.getHeaderBlock();
1127  assert(headerBlock->getOperations().size() == 1);
1128  auto condBrOp = dyn_cast<spirv::BranchConditionalOp>(
1129  headerBlock->getOperations().front());
1130  if (!condBrOp)
1131  return failure();
1132  rewriter.eraseBlock(headerBlock);
1133 
1134  // Branch from merge block to continue block.
1135  auto *mergeBlock = op.getMergeBlock();
1136  Operation *terminator = mergeBlock->getTerminator();
1137  ValueRange terminatorOperands = terminator->getOperands();
1138  rewriter.setInsertionPointToEnd(mergeBlock);
1139  rewriter.create<LLVM::BrOp>(loc, terminatorOperands, continueBlock);
1140 
1141  // Link current block to `true` and `false` blocks within the selection.
1142  Block *trueBlock = condBrOp.getTrueBlock();
1143  Block *falseBlock = condBrOp.getFalseBlock();
1144  rewriter.setInsertionPointToEnd(currentBlock);
1145  rewriter.create<LLVM::CondBrOp>(loc, condBrOp.condition(), trueBlock,
1146  condBrOp.trueTargetOperands(), falseBlock,
1147  condBrOp.falseTargetOperands());
1148 
1149  rewriter.inlineRegionBefore(op.body(), continueBlock);
1150  rewriter.replaceOp(op, continueBlock->getArguments());
1151  return success();
1152  }
1153 };
1154 
1155 /// Converts SPIR-V shift ops to LLVM shift ops. Since LLVM dialect
1156 /// puts a restriction on `Shift` and `Base` to have the same bit width,
1157 /// `Shift` is zero or sign extended to match this specification. Cases when
1158 /// `Shift` bit width > `Base` bit width are considered to be illegal.
1159 template <typename SPIRVOp, typename LLVMOp>
1160 class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> {
1161 public:
1163 
1165  matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
1166  ConversionPatternRewriter &rewriter) const override {
1167 
1168  auto dstType = this->typeConverter.convertType(operation.getType());
1169  if (!dstType)
1170  return failure();
1171 
1172  Type op1Type = operation.operand1().getType();
1173  Type op2Type = operation.operand2().getType();
1174 
1175  if (op1Type == op2Type) {
1176  rewriter.template replaceOpWithNewOp<LLVMOp>(operation, dstType,
1177  adaptor.getOperands());
1178  return success();
1179  }
1180 
1181  Location loc = operation.getLoc();
1182  Value extended;
1183  if (isUnsignedIntegerOrVector(op2Type)) {
1184  extended = rewriter.template create<LLVM::ZExtOp>(loc, dstType,
1185  adaptor.operand2());
1186  } else {
1187  extended = rewriter.template create<LLVM::SExtOp>(loc, dstType,
1188  adaptor.operand2());
1189  }
1190  Value result = rewriter.template create<LLVMOp>(
1191  loc, dstType, adaptor.operand1(), extended);
1192  rewriter.replaceOp(operation, result);
1193  return success();
1194  }
1195 };
1196 
1197 class TanPattern : public SPIRVToLLVMConversion<spirv::GLSLTanOp> {
1198 public:
1200 
1202  matchAndRewrite(spirv::GLSLTanOp tanOp, OpAdaptor adaptor,
1203  ConversionPatternRewriter &rewriter) const override {
1204  auto dstType = typeConverter.convertType(tanOp.getType());
1205  if (!dstType)
1206  return failure();
1207 
1208  Location loc = tanOp.getLoc();
1209  Value sin = rewriter.create<LLVM::SinOp>(loc, dstType, tanOp.operand());
1210  Value cos = rewriter.create<LLVM::CosOp>(loc, dstType, tanOp.operand());
1211  rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanOp, dstType, sin, cos);
1212  return success();
1213  }
1214 };
1215 
1216 /// Convert `spv.Tanh` to
1217 ///
1218 /// exp(2x) - 1
1219 /// -----------
1220 /// exp(2x) + 1
1221 ///
1222 class TanhPattern : public SPIRVToLLVMConversion<spirv::GLSLTanhOp> {
1223 public:
1225 
1227  matchAndRewrite(spirv::GLSLTanhOp tanhOp, OpAdaptor adaptor,
1228  ConversionPatternRewriter &rewriter) const override {
1229  auto srcType = tanhOp.getType();
1230  auto dstType = typeConverter.convertType(srcType);
1231  if (!dstType)
1232  return failure();
1233 
1234  Location loc = tanhOp.getLoc();
1235  Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0);
1236  Value multiplied =
1237  rewriter.create<LLVM::FMulOp>(loc, dstType, two, tanhOp.operand());
1238  Value exponential = rewriter.create<LLVM::ExpOp>(loc, dstType, multiplied);
1239  Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
1240  Value numerator =
1241  rewriter.create<LLVM::FSubOp>(loc, dstType, exponential, one);
1242  Value denominator =
1243  rewriter.create<LLVM::FAddOp>(loc, dstType, exponential, one);
1244  rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanhOp, dstType, numerator,
1245  denominator);
1246  return success();
1247  }
1248 };
1249 
1250 class VariablePattern : public SPIRVToLLVMConversion<spirv::VariableOp> {
1251 public:
1253 
1255  matchAndRewrite(spirv::VariableOp varOp, OpAdaptor adaptor,
1256  ConversionPatternRewriter &rewriter) const override {
1257  auto srcType = varOp.getType();
1258  // Initialization is supported for scalars and vectors only.
1259  auto pointerTo = srcType.cast<spirv::PointerType>().getPointeeType();
1260  auto init = varOp.initializer();
1261  if (init && !pointerTo.isIntOrFloat() && !pointerTo.isa<VectorType>())
1262  return failure();
1263 
1264  auto dstType = typeConverter.convertType(srcType);
1265  if (!dstType)
1266  return failure();
1267 
1268  Location loc = varOp.getLoc();
1269  Value size = createI32ConstantOf(loc, rewriter, 1);
1270  if (!init) {
1271  rewriter.replaceOpWithNewOp<LLVM::AllocaOp>(varOp, dstType, size);
1272  return success();
1273  }
1274  Value allocated = rewriter.create<LLVM::AllocaOp>(loc, dstType, size);
1275  rewriter.create<LLVM::StoreOp>(loc, adaptor.initializer(), allocated);
1276  rewriter.replaceOp(varOp, allocated);
1277  return success();
1278  }
1279 };
1280 
1281 //===----------------------------------------------------------------------===//
1282 // FuncOp conversion
1283 //===----------------------------------------------------------------------===//
1284 
1285 class FuncConversionPattern : public SPIRVToLLVMConversion<spirv::FuncOp> {
1286 public:
1288 
1290  matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor,
1291  ConversionPatternRewriter &rewriter) const override {
1292 
1293  // Convert function signature. At the moment LLVMType converter is enough
1294  // for currently supported types.
1295  auto funcType = funcOp.getType();
1296  TypeConverter::SignatureConversion signatureConverter(
1297  funcType.getNumInputs());
1298  auto llvmType = typeConverter.convertFunctionSignature(
1299  funcOp.getType(), /*isVariadic=*/false, signatureConverter);
1300  if (!llvmType)
1301  return failure();
1302 
1303  // Create a new `LLVMFuncOp`
1304  Location loc = funcOp.getLoc();
1305  StringRef name = funcOp.getName();
1306  auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(loc, name, llvmType);
1307 
1308  // Convert SPIR-V Function Control to equivalent LLVM function attribute
1309  MLIRContext *context = funcOp.getContext();
1310  switch (funcOp.function_control()) {
1311 #define DISPATCH(functionControl, llvmAttr) \
1312  case functionControl: \
1313  newFuncOp->setAttr("passthrough", ArrayAttr::get(context, {llvmAttr})); \
1314  break;
1315 
1316  DISPATCH(spirv::FunctionControl::Inline,
1317  StringAttr::get(context, "alwaysinline"));
1318  DISPATCH(spirv::FunctionControl::DontInline,
1319  StringAttr::get(context, "noinline"));
1320  DISPATCH(spirv::FunctionControl::Pure,
1321  StringAttr::get(context, "readonly"));
1322  DISPATCH(spirv::FunctionControl::Const,
1323  StringAttr::get(context, "readnone"));
1324 
1325 #undef DISPATCH
1326 
1327  // Default: if `spirv::FunctionControl::None`, then no attributes are
1328  // needed.
1329  default:
1330  break;
1331  }
1332 
1333  rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
1334  newFuncOp.end());
1335  if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter,
1336  &signatureConverter))) {
1337  return failure();
1338  }
1339  rewriter.eraseOp(funcOp);
1340  return success();
1341  }
1342 };
1343 
1344 //===----------------------------------------------------------------------===//
1345 // ModuleOp conversion
1346 //===----------------------------------------------------------------------===//
1347 
1348 class ModuleConversionPattern : public SPIRVToLLVMConversion<spirv::ModuleOp> {
1349 public:
1351 
1353  matchAndRewrite(spirv::ModuleOp spvModuleOp, OpAdaptor adaptor,
1354  ConversionPatternRewriter &rewriter) const override {
1355 
1356  auto newModuleOp =
1357  rewriter.create<ModuleOp>(spvModuleOp.getLoc(), spvModuleOp.getName());
1358  rewriter.inlineRegionBefore(spvModuleOp.getRegion(), newModuleOp.getBody());
1359 
1360  // Remove the terminator block that was automatically added by builder
1361  rewriter.eraseBlock(&newModuleOp.getBodyRegion().back());
1362  rewriter.eraseOp(spvModuleOp);
1363  return success();
1364  }
1365 };
1366 
1367 //===----------------------------------------------------------------------===//
1368 // VectorShuffleOp conversion
1369 //===----------------------------------------------------------------------===//
1370 
1371 class VectorShufflePattern
1372  : public SPIRVToLLVMConversion<spirv::VectorShuffleOp> {
1373 public:
1376  matchAndRewrite(spirv::VectorShuffleOp op, OpAdaptor adaptor,
1377  ConversionPatternRewriter &rewriter) const override {
1378  Location loc = op.getLoc();
1379  auto components = adaptor.components();
1380  auto vector1 = adaptor.vector1();
1381  auto vector2 = adaptor.vector2();
1382  int vector1Size = vector1.getType().cast<VectorType>().getNumElements();
1383  int vector2Size = vector2.getType().cast<VectorType>().getNumElements();
1384  if (vector1Size == vector2Size) {
1385  rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(op, vector1, vector2,
1386  components);
1387  return success();
1388  }
1389 
1390  auto dstType = typeConverter.convertType(op.getType());
1391  auto scalarType = dstType.cast<VectorType>().getElementType();
1392  auto componentsArray = components.getValue();
1393  auto *context = rewriter.getContext();
1394  auto llvmI32Type = IntegerType::get(context, 32);
1395  Value targetOp = rewriter.create<LLVM::UndefOp>(loc, dstType);
1396  for (unsigned i = 0; i < componentsArray.size(); i++) {
1397  if (componentsArray[i].isa<IntegerAttr>())
1398  op.emitError("unable to support non-constant component");
1399 
1400  int indexVal = componentsArray[i].cast<IntegerAttr>().getInt();
1401  if (indexVal == -1)
1402  continue;
1403 
1404  int offsetVal = 0;
1405  Value baseVector = vector1;
1406  if (indexVal >= vector1Size) {
1407  offsetVal = vector1Size;
1408  baseVector = vector2;
1409  }
1410 
1411  Value dstIndex = rewriter.create<LLVM::ConstantOp>(
1412  loc, llvmI32Type, rewriter.getIntegerAttr(rewriter.getI32Type(), i));
1413  Value index = rewriter.create<LLVM::ConstantOp>(
1414  loc, llvmI32Type,
1415  rewriter.getIntegerAttr(rewriter.getI32Type(), indexVal - offsetVal));
1416 
1417  auto extractOp = rewriter.create<LLVM::ExtractElementOp>(
1418  loc, scalarType, baseVector, index);
1419  targetOp = rewriter.create<LLVM::InsertElementOp>(loc, dstType, targetOp,
1420  extractOp, dstIndex);
1421  }
1422  rewriter.replaceOp(op, targetOp);
1423  return success();
1424  }
1425 };
1426 } // namespace
1427 
1428 //===----------------------------------------------------------------------===//
1429 // Pattern population
1430 //===----------------------------------------------------------------------===//
1431 
1433  typeConverter.addConversion([&](spirv::ArrayType type) {
1434  return convertArrayType(type, typeConverter);
1435  });
1436  typeConverter.addConversion([&](spirv::PointerType type) {
1437  return convertPointerType(type, typeConverter);
1438  });
1439  typeConverter.addConversion([&](spirv::RuntimeArrayType type) {
1440  return convertRuntimeArrayType(type, typeConverter);
1441  });
1442  typeConverter.addConversion([&](spirv::StructType type) {
1443  return convertStructType(type, typeConverter);
1444  });
1445 }
1446 
1448  LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
1449  patterns.add<
1450  // Arithmetic ops
1451  DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>,
1452  DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>,
1453  DirectConversionPattern<spirv::ISubOp, LLVM::SubOp>,
1454  DirectConversionPattern<spirv::FAddOp, LLVM::FAddOp>,
1455  DirectConversionPattern<spirv::FDivOp, LLVM::FDivOp>,
1456  DirectConversionPattern<spirv::FMulOp, LLVM::FMulOp>,
1457  DirectConversionPattern<spirv::FNegateOp, LLVM::FNegOp>,
1458  DirectConversionPattern<spirv::FRemOp, LLVM::FRemOp>,
1459  DirectConversionPattern<spirv::FSubOp, LLVM::FSubOp>,
1460  DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>,
1461  DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>,
1462  DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>,
1463  DirectConversionPattern<spirv::UModOp, LLVM::URemOp>,
1464 
1465  // Bitwise ops
1466  BitFieldInsertPattern, BitFieldUExtractPattern, BitFieldSExtractPattern,
1467  DirectConversionPattern<spirv::BitCountOp, LLVM::CtPopOp>,
1468  DirectConversionPattern<spirv::BitReverseOp, LLVM::BitReverseOp>,
1469  DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>,
1470  DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>,
1471  DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>,
1472  NotPattern<spirv::NotOp>,
1473 
1474  // Cast ops
1475  DirectConversionPattern<spirv::BitcastOp, LLVM::BitcastOp>,
1476  DirectConversionPattern<spirv::ConvertFToSOp, LLVM::FPToSIOp>,
1477  DirectConversionPattern<spirv::ConvertFToUOp, LLVM::FPToUIOp>,
1478  DirectConversionPattern<spirv::ConvertSToFOp, LLVM::SIToFPOp>,
1479  DirectConversionPattern<spirv::ConvertUToFOp, LLVM::UIToFPOp>,
1480  IndirectCastPattern<spirv::FConvertOp, LLVM::FPExtOp, LLVM::FPTruncOp>,
1481  IndirectCastPattern<spirv::SConvertOp, LLVM::SExtOp, LLVM::TruncOp>,
1482  IndirectCastPattern<spirv::UConvertOp, LLVM::ZExtOp, LLVM::TruncOp>,
1483 
1484  // Comparison ops
1485  IComparePattern<spirv::IEqualOp, LLVM::ICmpPredicate::eq>,
1486  IComparePattern<spirv::INotEqualOp, LLVM::ICmpPredicate::ne>,
1487  FComparePattern<spirv::FOrdEqualOp, LLVM::FCmpPredicate::oeq>,
1488  FComparePattern<spirv::FOrdGreaterThanOp, LLVM::FCmpPredicate::ogt>,
1489  FComparePattern<spirv::FOrdGreaterThanEqualOp, LLVM::FCmpPredicate::oge>,
1490  FComparePattern<spirv::FOrdLessThanEqualOp, LLVM::FCmpPredicate::ole>,
1491  FComparePattern<spirv::FOrdLessThanOp, LLVM::FCmpPredicate::olt>,
1492  FComparePattern<spirv::FOrdNotEqualOp, LLVM::FCmpPredicate::one>,
1493  FComparePattern<spirv::FUnordEqualOp, LLVM::FCmpPredicate::ueq>,
1494  FComparePattern<spirv::FUnordGreaterThanOp, LLVM::FCmpPredicate::ugt>,
1495  FComparePattern<spirv::FUnordGreaterThanEqualOp,
1496  LLVM::FCmpPredicate::uge>,
1497  FComparePattern<spirv::FUnordLessThanEqualOp, LLVM::FCmpPredicate::ule>,
1498  FComparePattern<spirv::FUnordLessThanOp, LLVM::FCmpPredicate::ult>,
1499  FComparePattern<spirv::FUnordNotEqualOp, LLVM::FCmpPredicate::une>,
1500  IComparePattern<spirv::SGreaterThanOp, LLVM::ICmpPredicate::sgt>,
1501  IComparePattern<spirv::SGreaterThanEqualOp, LLVM::ICmpPredicate::sge>,
1502  IComparePattern<spirv::SLessThanEqualOp, LLVM::ICmpPredicate::sle>,
1503  IComparePattern<spirv::SLessThanOp, LLVM::ICmpPredicate::slt>,
1504  IComparePattern<spirv::UGreaterThanOp, LLVM::ICmpPredicate::ugt>,
1505  IComparePattern<spirv::UGreaterThanEqualOp, LLVM::ICmpPredicate::uge>,
1506  IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>,
1507  IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>,
1508 
1509  // Constant op
1510  ConstantScalarAndVectorPattern,
1511 
1512  // Control Flow ops
1513  BranchConversionPattern, BranchConditionalConversionPattern,
1514  FunctionCallPattern, LoopPattern, SelectionPattern,
1515  ErasePattern<spirv::MergeOp>,
1516 
1517  // Entry points and execution mode are handled separately.
1518  ErasePattern<spirv::EntryPointOp>, ExecutionModePattern,
1519 
1520  // GLSL extended instruction set ops
1521  DirectConversionPattern<spirv::GLSLCeilOp, LLVM::FCeilOp>,
1522  DirectConversionPattern<spirv::GLSLCosOp, LLVM::CosOp>,
1523  DirectConversionPattern<spirv::GLSLExpOp, LLVM::ExpOp>,
1524  DirectConversionPattern<spirv::GLSLFAbsOp, LLVM::FAbsOp>,
1525  DirectConversionPattern<spirv::GLSLFloorOp, LLVM::FFloorOp>,
1526  DirectConversionPattern<spirv::GLSLFMaxOp, LLVM::MaxNumOp>,
1527  DirectConversionPattern<spirv::GLSLFMinOp, LLVM::MinNumOp>,
1528  DirectConversionPattern<spirv::GLSLLogOp, LLVM::LogOp>,
1529  DirectConversionPattern<spirv::GLSLSinOp, LLVM::SinOp>,
1530  DirectConversionPattern<spirv::GLSLSMaxOp, LLVM::SMaxOp>,
1531  DirectConversionPattern<spirv::GLSLSMinOp, LLVM::SMinOp>,
1532  DirectConversionPattern<spirv::GLSLSqrtOp, LLVM::SqrtOp>,
1533  InverseSqrtPattern, TanPattern, TanhPattern,
1534 
1535  // Logical ops
1536  DirectConversionPattern<spirv::LogicalAndOp, LLVM::AndOp>,
1537  DirectConversionPattern<spirv::LogicalOrOp, LLVM::OrOp>,
1538  IComparePattern<spirv::LogicalEqualOp, LLVM::ICmpPredicate::eq>,
1539  IComparePattern<spirv::LogicalNotEqualOp, LLVM::ICmpPredicate::ne>,
1540  NotPattern<spirv::LogicalNotOp>,
1541 
1542  // Memory ops
1543  AccessChainPattern, AddressOfPattern, GlobalVariablePattern,
1544  LoadStorePattern<spirv::LoadOp>, LoadStorePattern<spirv::StoreOp>,
1545  VariablePattern,
1546 
1547  // Miscellaneous ops
1548  CompositeExtractPattern, CompositeInsertPattern,
1549  DirectConversionPattern<spirv::SelectOp, LLVM::SelectOp>,
1550  DirectConversionPattern<spirv::UndefOp, LLVM::UndefOp>,
1551  VectorShufflePattern,
1552 
1553  // Shift ops
1554  ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>,
1555  ShiftPattern<spirv::ShiftRightLogicalOp, LLVM::LShrOp>,
1556  ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>,
1557 
1558  // Return ops
1559  ReturnPattern, ReturnValuePattern>(patterns.getContext(), typeConverter);
1560 }
1561 
1563  LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
1564  patterns.add<FuncConversionPattern>(patterns.getContext(), typeConverter);
1565 }
1566 
1568  LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
1569  patterns.add<ModuleConversionPattern>(patterns.getContext(), typeConverter);
1570 }
1571 
1572 //===----------------------------------------------------------------------===//
1573 // Pre-conversion hooks
1574 //===----------------------------------------------------------------------===//
1575 
1576 /// Hook for descriptor set and binding number encoding.
1577 static constexpr StringRef kBinding = "binding";
1578 static constexpr StringRef kDescriptorSet = "descriptor_set";
1579 void mlir::encodeBindAttribute(ModuleOp module) {
1580  auto spvModules = module.getOps<spirv::ModuleOp>();
1581  for (auto spvModule : spvModules) {
1582  spvModule.walk([&](spirv::GlobalVariableOp op) {
1583  IntegerAttr descriptorSet =
1584  op->getAttrOfType<IntegerAttr>(kDescriptorSet);
1585  IntegerAttr binding = op->getAttrOfType<IntegerAttr>(kBinding);
1586  // For every global variable in the module, get the ones with descriptor
1587  // set and binding numbers.
1588  if (descriptorSet && binding) {
1589  // Encode these numbers into the variable's symbolic name. If the
1590  // SPIR-V module has a name, add it at the beginning.
1591  auto moduleAndName = spvModule.getName().hasValue()
1592  ? spvModule.getName().getValue().str() + "_" +
1593  op.sym_name().str()
1594  : op.sym_name().str();
1595  std::string name =
1596  llvm::formatv("{0}_descriptor_set{1}_binding{2}", moduleAndName,
1597  std::to_string(descriptorSet.getInt()),
1598  std::to_string(binding.getInt()));
1599  auto nameAttr = StringAttr::get(op->getContext(), name);
1600 
1601  // Replace all symbol uses and set the new symbol name. Finally, remove
1602  // descriptor set and binding attributes.
1603  if (failed(SymbolTable::replaceAllSymbolUses(op, nameAttr, spvModule)))
1604  op.emitError("unable to replace all symbol uses for ") << name;
1605  SymbolTable::setSymbolName(op, nameAttr);
1606  op->removeAttr(kDescriptorSet);
1607  op->removeAttr(kBinding);
1608  }
1609  });
1610  }
1611 }
Include the generated interface declarations.
OpTy create(Location location, Args &&...args)
Create an operation of specific op type at the current insertion point.
Definition: Builders.h:430
This class contains a list of basic blocks and a link to the parent operation it is attached to...
Definition: Region.h:26
iterator begin()
Definition: Block.h:134
U cast() const
Definition: Location.h:67
MLIRContext * getContext() const
Definition: Builders.h:54
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:881
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:373
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter, LLVMTypeConverter &typeConverter, unsigned alignment, bool isVolatile, bool isNonTemporal)
Utility for spv.Load and spv.Store conversion.
static LLVMStructType getLiteral(MLIRContext *context, ArrayRef< Type > types, bool isPacked=false)
Gets or creates a literal struct with the given body in the provided context.
Definition: LLVMTypes.cpp:371
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:64
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:247
Type getPointeeType() const
Definition: SPIRVTypes.cpp:395
static LLVMArrayType get(Type elementType, unsigned numElements)
Gets or creates an instance of LLVM dialect array type containing numElements of elementType, in the same context as elementType.
Definition: LLVMTypes.cpp:39
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results)
Convert the given type.
Block represents an ordered list of Operations.
Definition: Block.h:29
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:329
OpListType & getOperations()
Definition: Block.h:128
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
unsigned getNumElements() const
Definition: SPIRVTypes.cpp:60
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity, to select an element type.
Definition: SPIRVOps.cpp:639
static Value optionallyTruncateOrExtend(Location loc, Value value, Type llvmType, PatternRewriter &rewriter)
Utility function for bitfield ops: BitFieldInsert, BitFieldSExtract and BitFieldUExtract.
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:87
static Optional< Type > convertStructType(spirv::StructType type, LLVMTypeConverter &converter)
Converts SPIR-V struct to LLVM struct.
static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder)
Creates IntegerAttribute with all bits set for given type.
Definition: SPIRVToLLVM.cpp:76
static LLVMPointerType get(Type pointee, unsigned addressSpace=0)
Gets or creates an instance of LLVM dialect pointer type pointing to an object of pointee type in the given address space.
Definition: LLVMTypes.cpp:165
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition: Types.cpp:61
static bool isSignedIntegerOrVector(Type type)
Returns true if the given type is a signed integer or vector type.
Definition: SPIRVToLLVM.cpp:37
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing the results of an operation.
static constexpr const bool value
static Optional< Type > convertStructTypeWithOffset(spirv::StructType type, LLVMTypeConverter &converter)
Converts SPIR-V struct with a regular (according to VulkanLayoutUtils) offset to LLVM struct...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:48
#define DISPATCH(functionControl, llvmAttr)
Block * getBlock() const
Returns the current block of the builder.
Definition: Builders.h:379
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
Definition: Builders.h:343
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:193
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:148
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
static void setSymbolName(Operation *symbol, StringAttr name)
Sets the name of the given symbol operation.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
This class provides all of the information necessary to convert a type signature. ...
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:752
OpListType::iterator iterator
Definition: Block.h:131
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:170
U dyn_cast() const
Definition: Types.h:244
static Type convertPointerType(spirv::PointerType type, TypeConverter &converter)
Converts SPIR-V pointer type to LLVM pointer.
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
Attributes are known-constant values of operations.
Definition: Attributes.h:24
void populateSPIRVToLLVMModuleConversionPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the given patterns for module conversion from SPIR-V to LLVM.
Type getElementType() const
Definition: SPIRVTypes.cpp:62
void eraseBlock(Block *block) override
PatternRewriter hook for erasing all operations in a block.
static llvm::Value * getSizeInBytes(llvm::IRBuilderBase &builder, llvm::Value *basePtr)
Computes the size of type in bytes.
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:58
void populateSPIRVToLLVMConversionPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the given list with patterns that convert from SPIR-V to LLVM.
static unsigned getBitWidth(Type type)
Returns the bit width of integer, float or vector of float or integer values.
Definition: SPIRVToLLVM.cpp:55
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
Definition: Types.cpp:49
BlockArgListType getArguments()
Definition: Block.h:76
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before) override
PatternRewriter hook for moving blocks out of a region.
static unsigned getLLVMTypeBitWidth(Type type)
Returns the bit width of LLVMType integer or vector.
Definition: SPIRVToLLVM.cpp:68
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:72
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:84
static constexpr StringRef kBinding
Hook for descriptor set and binding number encoding.
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:230
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:362
static Optional< Type > convertRuntimeArrayType(spirv::RuntimeArrayType type, TypeConverter &converter)
Converts SPIR-V runtime array to LLVM array.
ElementTypeRange getElementTypes() const
Definition: SPIRVTypes.cpp:996
bool isa() const
Definition: Value.h:89
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:279
void addConversion(FnT &&callback)
Register a conversion function.
OpTy replaceOpWithNewOp(Operation *op, Args &&... args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:741
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:458
This class is a general helper class for creating context-global objects like types, attributes, and affine expressions.
Definition: Builders.h:49
Type getType() const
Return the type of this value.
Definition: Value.h:117
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&... args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:930
Type conversion class.
static int64_t getNumElements(ShapedType type)
Definition: TensorOps.cpp:678
bool isCompatibleVectorType(Type type)
Returns true if the given type is a vector type compatible with the LLVM dialect. ...
Definition: LLVMTypes.cpp:762
static VectorType vectorType(CodeGen &codegen, Type etp)
Constructs vector type.
Conversion from types in the Standard dialect to the LLVM IR dialect.
Definition: TypeConverter.h:30
static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter, unsigned value)
Creates LLVM dialect constant with the given value.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
static Type convertStructTypePacked(spirv::StructType type, LLVMTypeConverter &converter)
Converts SPIR-V struct with no offset to packed LLVM struct.
Block * splitBlock(Block *block, Block::iterator before) override
PatternRewriter hook for splitting a block into two parts.
SPIR-V struct type.
Definition: SPIRVTypes.h:278
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type; asserts on any other type.
Definition: Types.cpp:91
This class implements a pattern rewriter for use with ConversionPatterns.
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
static Value processCountOrOffset(Location loc, Value value, Type srcType, Type dstType, LLVMTypeConverter &converter, ConversionPatternRewriter &rewriter)
Utility function for bitfield ops: BitFieldInsert, BitFieldSExtract and BitFieldUExtract.
U cast() const
Definition: Value.h:107
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:367
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void encodeBindAttribute(ModuleOp module)
Encodes a global variable's descriptor set and binding into its name if they both exist.
static constexpr StringRef kDescriptorSet
static Optional< Type > convertArrayType(spirv::ArrayType type, TypeConverter &converter)
Converts SPIR-V array type to LLVM array.
static spirv::StructType decorateType(spirv::StructType structType)
Returns a new StructType with layout decoration.
Definition: LayoutUtils.cpp:21
static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, StringAttr newSymbol, Operation *from)
Attempt to replace all uses of the given symbol 'oldSymbol' with the provided symbol 'newSymbol' that are nested within the given operation 'from'.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=llvm::None, ArrayRef< Location > locs=llvm::None)
Add a new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:353
static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter)
Creates llvm.mlir.constant with all bits set for the given type.
Definition: SPIRVToLLVM.cpp:86
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:376
void populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter)
Populates type conversions with additional SPIR-V types.
void populateSPIRVToLLVMFunctionConversionPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the given list with patterns for function conversion from SPIR-V to LLVM.
bool isa() const
Definition: Types.h:234
static Value optionallyBroadcast(Location loc, Value value, Type srcType, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value. If srcType is a scalar, the value remains unchanged.
static bool isUnsignedIntegerOrVector(Type type)
Returns true if the given type is an unsigned integer or vector type.
Definition: SPIRVToLLVM.cpp:46
This class provides an abstraction over the different types of ranges over Values.
Type getVectorElementType(Type type)
Returns the element type of any vector type compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:778
MLIRContext * getContext() const
Definition: PatternMatch.h:906
IntegerType getI32Type()
Definition: Builders.cpp:54
FailureOr< Block * > convertRegionTypes(Region *region, TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Convert the types of block arguments within the given region.
static Value createFPConstant(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter, double value)
Creates llvm.mlir.constant with a floating-point scalar or vector value.
Definition: SPIRVToLLVM.cpp:99
An attribute that represents a reference to a dense integer vector or tensor object.
U cast() const
Definition: Types.h:250