MLIR  19.0.0git
ConvertVectorToLLVM.cpp
Go to the documentation of this file.
1 //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
26 #include "mlir/IR/BuiltinTypes.h"
27 #include "mlir/IR/TypeUtilities.h"
30 #include "llvm/ADT/APFloat.h"
31 #include "llvm/Support/Casting.h"
32 #include <optional>
33 
34 using namespace mlir;
35 using namespace mlir::vector;
36 
37 // Helper to reduce vector type by *all* but one rank at back.
38 static VectorType reducedVectorTypeBack(VectorType tp) {
39  assert((tp.getRank() > 1) && "unlowerable vector type");
40  return VectorType::get(tp.getShape().take_back(), tp.getElementType(),
41  tp.getScalableDims().take_back());
42 }
43 
44 // Helper that picks the proper sequence for inserting.
46  const LLVMTypeConverter &typeConverter, Location loc,
47  Value val1, Value val2, Type llvmType, int64_t rank,
48  int64_t pos) {
49  assert(rank > 0 && "0-D vector corner case should have been handled already");
50  if (rank == 1) {
51  auto idxType = rewriter.getIndexType();
52  auto constant = rewriter.create<LLVM::ConstantOp>(
53  loc, typeConverter.convertType(idxType),
54  rewriter.getIntegerAttr(idxType, pos));
55  return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
56  constant);
57  }
58  return rewriter.create<LLVM::InsertValueOp>(loc, val1, val2, pos);
59 }
60 
61 // Helper that picks the proper sequence for extracting.
63  const LLVMTypeConverter &typeConverter, Location loc,
64  Value val, Type llvmType, int64_t rank, int64_t pos) {
65  if (rank <= 1) {
66  auto idxType = rewriter.getIndexType();
67  auto constant = rewriter.create<LLVM::ConstantOp>(
68  loc, typeConverter.convertType(idxType),
69  rewriter.getIntegerAttr(idxType, pos));
70  return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
71  constant);
72  }
73  return rewriter.create<LLVM::ExtractValueOp>(loc, val, pos);
74 }
75 
76 // Helper that returns data layout alignment of a memref.
78  MemRefType memrefType, unsigned &align) {
79  Type elementTy = typeConverter.convertType(memrefType.getElementType());
80  if (!elementTy)
81  return failure();
82 
83  // TODO: this should use the MLIR data layout when it becomes available and
84  // stop depending on translation.
85  llvm::LLVMContext llvmContext;
86  align = LLVM::TypeToLLVMIRTranslator(llvmContext)
87  .getPreferredAlignment(elementTy, typeConverter.getDataLayout());
88  return success();
89 }
90 
91 // Check if the last stride is non-unit and has a valid memory space.
92 static LogicalResult isMemRefTypeSupported(MemRefType memRefType,
93  const LLVMTypeConverter &converter) {
94  if (!isLastMemrefDimUnitStride(memRefType))
95  return failure();
96  if (failed(converter.getMemRefAddressSpace(memRefType)))
97  return failure();
98  return success();
99 }
100 
101 // Add an index vector component to a base pointer.
103  const LLVMTypeConverter &typeConverter,
104  MemRefType memRefType, Value llvmMemref, Value base,
105  Value index, uint64_t vLen) {
106  assert(succeeded(isMemRefTypeSupported(memRefType, typeConverter)) &&
107  "unsupported memref type");
108  auto pType = MemRefDescriptor(llvmMemref).getElementPtrType();
109  auto ptrsType = LLVM::getFixedVectorType(pType, vLen);
110  return rewriter.create<LLVM::GEPOp>(
111  loc, ptrsType, typeConverter.convertType(memRefType.getElementType()),
112  base, index);
113 }
114 
115 /// Convert `foldResult` into a Value. Integer attribute is converted to
116 /// an LLVM constant op.
117 static Value getAsLLVMValue(OpBuilder &builder, Location loc,
118  OpFoldResult foldResult) {
119  if (auto attr = foldResult.dyn_cast<Attribute>()) {
120  auto intAttr = cast<IntegerAttr>(attr);
121  return builder.create<LLVM::ConstantOp>(loc, intAttr).getResult();
122  }
123 
124  return foldResult.get<Value>();
125 }
126 
127 namespace {
128 
129 /// Trivial Vector to LLVM conversions
130 using VectorScaleOpConversion =
132 
133 /// Conversion pattern for a vector.bitcast.
134 class VectorBitCastOpConversion
135  : public ConvertOpToLLVMPattern<vector::BitCastOp> {
136 public:
138 
140  matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor,
141  ConversionPatternRewriter &rewriter) const override {
142  // Only 0-D and 1-D vectors can be lowered to LLVM.
143  VectorType resultTy = bitCastOp.getResultVectorType();
144  if (resultTy.getRank() > 1)
145  return failure();
146  Type newResultTy = typeConverter->convertType(resultTy);
147  rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy,
148  adaptor.getOperands()[0]);
149  return success();
150  }
151 };
152 
153 /// Conversion pattern for a vector.matrix_multiply.
154 /// This is lowered directly to the proper llvm.intr.matrix.multiply.
155 class VectorMatmulOpConversion
156  : public ConvertOpToLLVMPattern<vector::MatmulOp> {
157 public:
159 
161  matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor,
162  ConversionPatternRewriter &rewriter) const override {
163  rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
164  matmulOp, typeConverter->convertType(matmulOp.getRes().getType()),
165  adaptor.getLhs(), adaptor.getRhs(), matmulOp.getLhsRows(),
166  matmulOp.getLhsColumns(), matmulOp.getRhsColumns());
167  return success();
168  }
169 };
170 
171 /// Conversion pattern for a vector.flat_transpose.
172 /// This is lowered directly to the proper llvm.intr.matrix.transpose.
173 class VectorFlatTransposeOpConversion
174  : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> {
175 public:
177 
179  matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor,
180  ConversionPatternRewriter &rewriter) const override {
181  rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
182  transOp, typeConverter->convertType(transOp.getRes().getType()),
183  adaptor.getMatrix(), transOp.getRows(), transOp.getColumns());
184  return success();
185  }
186 };
187 
188 /// Overloaded utility that replaces a vector.load, vector.store,
189 /// vector.maskedload and vector.maskedstore with their respective LLVM
190 /// couterparts.
191 static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
192  vector::LoadOpAdaptor adaptor,
193  VectorType vectorTy, Value ptr, unsigned align,
194  ConversionPatternRewriter &rewriter) {
195  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, vectorTy, ptr, align,
196  /*volatile_=*/false,
197  loadOp.getNontemporal());
198 }
199 
200 static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
201  vector::MaskedLoadOpAdaptor adaptor,
202  VectorType vectorTy, Value ptr, unsigned align,
203  ConversionPatternRewriter &rewriter) {
204  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
205  loadOp, vectorTy, ptr, adaptor.getMask(), adaptor.getPassThru(), align);
206 }
207 
208 static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
209  vector::StoreOpAdaptor adaptor,
210  VectorType vectorTy, Value ptr, unsigned align,
211  ConversionPatternRewriter &rewriter) {
212  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValueToStore(),
213  ptr, align, /*volatile_=*/false,
214  storeOp.getNontemporal());
215 }
216 
217 static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
218  vector::MaskedStoreOpAdaptor adaptor,
219  VectorType vectorTy, Value ptr, unsigned align,
220  ConversionPatternRewriter &rewriter) {
221  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
222  storeOp, adaptor.getValueToStore(), ptr, adaptor.getMask(), align);
223 }
224 
225 /// Conversion pattern for a vector.load, vector.store, vector.maskedload, and
226 /// vector.maskedstore.
227 template <class LoadOrStoreOp>
228 class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
229 public:
231 
233  matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
234  typename LoadOrStoreOp::Adaptor adaptor,
235  ConversionPatternRewriter &rewriter) const override {
236  // Only 1-D vectors can be lowered to LLVM.
237  VectorType vectorTy = loadOrStoreOp.getVectorType();
238  if (vectorTy.getRank() > 1)
239  return failure();
240 
241  auto loc = loadOrStoreOp->getLoc();
242  MemRefType memRefTy = loadOrStoreOp.getMemRefType();
243 
244  // Resolve alignment.
245  unsigned align;
246  if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
247  return failure();
248 
249  // Resolve address.
250  auto vtype = cast<VectorType>(
251  this->typeConverter->convertType(loadOrStoreOp.getVectorType()));
252  Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(),
253  adaptor.getIndices(), rewriter);
254  replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, dataPtr, align,
255  rewriter);
256  return success();
257  }
258 };
259 
260 /// Conversion pattern for a vector.gather.
261 class VectorGatherOpConversion
262  : public ConvertOpToLLVMPattern<vector::GatherOp> {
263 public:
265 
267  matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
268  ConversionPatternRewriter &rewriter) const override {
269  MemRefType memRefType = dyn_cast<MemRefType>(gather.getBaseType());
270  assert(memRefType && "The base should be bufferized");
271 
272  if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
273  return failure();
274 
275  auto loc = gather->getLoc();
276 
277  // Resolve alignment.
278  unsigned align;
279  if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
280  return failure();
281 
282  Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
283  adaptor.getIndices(), rewriter);
284  Value base = adaptor.getBase();
285 
286  auto llvmNDVectorTy = adaptor.getIndexVec().getType();
287  // Handle the simple case of 1-D vector.
288  if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy)) {
289  auto vType = gather.getVectorType();
290  // Resolve address.
291  Value ptrs = getIndexedPtrs(rewriter, loc, *this->getTypeConverter(),
292  memRefType, base, ptr, adaptor.getIndexVec(),
293  /*vLen=*/vType.getDimSize(0));
294  // Replace with the gather intrinsic.
295  rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
296  gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
297  adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
298  return success();
299  }
300 
301  const LLVMTypeConverter &typeConverter = *this->getTypeConverter();
302  auto callback = [align, memRefType, base, ptr, loc, &rewriter,
303  &typeConverter](Type llvm1DVectorTy,
304  ValueRange vectorOperands) {
305  // Resolve address.
306  Value ptrs = getIndexedPtrs(
307  rewriter, loc, typeConverter, memRefType, base, ptr,
308  /*index=*/vectorOperands[0],
309  LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue());
310  // Create the gather intrinsic.
311  return rewriter.create<LLVM::masked_gather>(
312  loc, llvm1DVectorTy, ptrs, /*mask=*/vectorOperands[1],
313  /*passThru=*/vectorOperands[2], rewriter.getI32IntegerAttr(align));
314  };
315  SmallVector<Value> vectorOperands = {
316  adaptor.getIndexVec(), adaptor.getMask(), adaptor.getPassThru()};
318  gather, vectorOperands, *getTypeConverter(), callback, rewriter);
319  }
320 };
321 
322 /// Conversion pattern for a vector.scatter.
323 class VectorScatterOpConversion
324  : public ConvertOpToLLVMPattern<vector::ScatterOp> {
325 public:
327 
329  matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
330  ConversionPatternRewriter &rewriter) const override {
331  auto loc = scatter->getLoc();
332  MemRefType memRefType = scatter.getMemRefType();
333 
334  if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
335  return failure();
336 
337  // Resolve alignment.
338  unsigned align;
339  if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
340  return failure();
341 
342  // Resolve address.
343  VectorType vType = scatter.getVectorType();
344  Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
345  adaptor.getIndices(), rewriter);
346  Value ptrs = getIndexedPtrs(
347  rewriter, loc, *this->getTypeConverter(), memRefType, adaptor.getBase(),
348  ptr, adaptor.getIndexVec(), /*vLen=*/vType.getDimSize(0));
349 
350  // Replace with the scatter intrinsic.
351  rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
352  scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(),
353  rewriter.getI32IntegerAttr(align));
354  return success();
355  }
356 };
357 
358 /// Conversion pattern for a vector.expandload.
359 class VectorExpandLoadOpConversion
360  : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
361 public:
363 
365  matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor,
366  ConversionPatternRewriter &rewriter) const override {
367  auto loc = expand->getLoc();
368  MemRefType memRefType = expand.getMemRefType();
369 
370  // Resolve address.
371  auto vtype = typeConverter->convertType(expand.getVectorType());
372  Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
373  adaptor.getIndices(), rewriter);
374 
375  rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
376  expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru());
377  return success();
378  }
379 };
380 
381 /// Conversion pattern for a vector.compressstore.
382 class VectorCompressStoreOpConversion
383  : public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
384 public:
386 
388  matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor,
389  ConversionPatternRewriter &rewriter) const override {
390  auto loc = compress->getLoc();
391  MemRefType memRefType = compress.getMemRefType();
392 
393  // Resolve address.
394  Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
395  adaptor.getIndices(), rewriter);
396 
397  rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
398  compress, adaptor.getValueToStore(), ptr, adaptor.getMask());
399  return success();
400  }
401 };
402 
/// Empty tag types used purely for overload selection: each one picks the
/// `createReductionNeutralValue` overload that materializes the corresponding
/// constant.
struct ReductionNeutralZero {};    // zero attribute of the type
struct ReductionNeutralIntOne {};  // integer 1
struct ReductionNeutralFPOne {};   // floating-point 1.0
struct ReductionNeutralAllOnes {}; // all bits set
struct ReductionNeutralSIntMin {}; // signed integer minimum
struct ReductionNeutralUIntMin {}; // unsigned integer minimum
struct ReductionNeutralSIntMax {}; // signed integer maximum
struct ReductionNeutralUIntMax {}; // unsigned integer maximum
struct ReductionNeutralFPMin {};   // positive quiet NaN
struct ReductionNeutralFPMax {};   // negative quiet NaN
414 
415 /// Create the reduction neutral zero value.
416 static Value createReductionNeutralValue(ReductionNeutralZero neutral,
417  ConversionPatternRewriter &rewriter,
418  Location loc, Type llvmType) {
419  return rewriter.create<LLVM::ConstantOp>(loc, llvmType,
420  rewriter.getZeroAttr(llvmType));
421 }
422 
423 /// Create the reduction neutral integer one value.
424 static Value createReductionNeutralValue(ReductionNeutralIntOne neutral,
425  ConversionPatternRewriter &rewriter,
426  Location loc, Type llvmType) {
427  return rewriter.create<LLVM::ConstantOp>(
428  loc, llvmType, rewriter.getIntegerAttr(llvmType, 1));
429 }
430 
431 /// Create the reduction neutral fp one value.
432 static Value createReductionNeutralValue(ReductionNeutralFPOne neutral,
433  ConversionPatternRewriter &rewriter,
434  Location loc, Type llvmType) {
435  return rewriter.create<LLVM::ConstantOp>(
436  loc, llvmType, rewriter.getFloatAttr(llvmType, 1.0));
437 }
438 
439 /// Create the reduction neutral all-ones value.
440 static Value createReductionNeutralValue(ReductionNeutralAllOnes neutral,
441  ConversionPatternRewriter &rewriter,
442  Location loc, Type llvmType) {
443  return rewriter.create<LLVM::ConstantOp>(
444  loc, llvmType,
445  rewriter.getIntegerAttr(
446  llvmType, llvm::APInt::getAllOnes(llvmType.getIntOrFloatBitWidth())));
447 }
448 
449 /// Create the reduction neutral signed int minimum value.
450 static Value createReductionNeutralValue(ReductionNeutralSIntMin neutral,
451  ConversionPatternRewriter &rewriter,
452  Location loc, Type llvmType) {
453  return rewriter.create<LLVM::ConstantOp>(
454  loc, llvmType,
455  rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMinValue(
456  llvmType.getIntOrFloatBitWidth())));
457 }
458 
459 /// Create the reduction neutral unsigned int minimum value.
460 static Value createReductionNeutralValue(ReductionNeutralUIntMin neutral,
461  ConversionPatternRewriter &rewriter,
462  Location loc, Type llvmType) {
463  return rewriter.create<LLVM::ConstantOp>(
464  loc, llvmType,
465  rewriter.getIntegerAttr(llvmType, llvm::APInt::getMinValue(
466  llvmType.getIntOrFloatBitWidth())));
467 }
468 
469 /// Create the reduction neutral signed int maximum value.
470 static Value createReductionNeutralValue(ReductionNeutralSIntMax neutral,
471  ConversionPatternRewriter &rewriter,
472  Location loc, Type llvmType) {
473  return rewriter.create<LLVM::ConstantOp>(
474  loc, llvmType,
475  rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMaxValue(
476  llvmType.getIntOrFloatBitWidth())));
477 }
478 
479 /// Create the reduction neutral unsigned int maximum value.
480 static Value createReductionNeutralValue(ReductionNeutralUIntMax neutral,
481  ConversionPatternRewriter &rewriter,
482  Location loc, Type llvmType) {
483  return rewriter.create<LLVM::ConstantOp>(
484  loc, llvmType,
485  rewriter.getIntegerAttr(llvmType, llvm::APInt::getMaxValue(
486  llvmType.getIntOrFloatBitWidth())));
487 }
488 
489 /// Create the reduction neutral fp minimum value.
490 static Value createReductionNeutralValue(ReductionNeutralFPMin neutral,
491  ConversionPatternRewriter &rewriter,
492  Location loc, Type llvmType) {
493  auto floatType = cast<FloatType>(llvmType);
494  return rewriter.create<LLVM::ConstantOp>(
495  loc, llvmType,
496  rewriter.getFloatAttr(
497  llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
498  /*Negative=*/false)));
499 }
500 
501 /// Create the reduction neutral fp maximum value.
502 static Value createReductionNeutralValue(ReductionNeutralFPMax neutral,
503  ConversionPatternRewriter &rewriter,
504  Location loc, Type llvmType) {
505  auto floatType = cast<FloatType>(llvmType);
506  return rewriter.create<LLVM::ConstantOp>(
507  loc, llvmType,
508  rewriter.getFloatAttr(
509  llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
510  /*Negative=*/true)));
511 }
512 
513 /// Returns `accumulator` if it has a valid value. Otherwise, creates and
514 /// returns a new accumulator value using `ReductionNeutral`.
515 template <class ReductionNeutral>
516 static Value getOrCreateAccumulator(ConversionPatternRewriter &rewriter,
517  Location loc, Type llvmType,
518  Value accumulator) {
519  if (accumulator)
520  return accumulator;
521 
522  return createReductionNeutralValue(ReductionNeutral(), rewriter, loc,
523  llvmType);
524 }
525 
526 /// Creates a constant value with the 1-D vector shape provided in `llvmType`.
527 /// This is used as effective vector length by some intrinsics supporting
528 /// dynamic vector lengths at runtime.
529 static Value createVectorLengthValue(ConversionPatternRewriter &rewriter,
530  Location loc, Type llvmType) {
531  VectorType vType = cast<VectorType>(llvmType);
532  auto vShape = vType.getShape();
533  assert(vShape.size() == 1 && "Unexpected multi-dim vector type");
534 
535  return rewriter.create<LLVM::ConstantOp>(
536  loc, rewriter.getI32Type(),
537  rewriter.getIntegerAttr(rewriter.getI32Type(), vShape[0]));
538 }
539 
540 /// Helper method to lower a `vector.reduction` op that performs an arithmetic
541 /// operation like add,mul, etc.. `VectorOp` is the LLVM vector intrinsic to use
542 /// and `ScalarOp` is the scalar operation used to add the accumulation value if
543 /// non-null.
544 template <class LLVMRedIntrinOp, class ScalarOp>
545 static Value createIntegerReductionArithmeticOpLowering(
546  ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
547  Value vectorOperand, Value accumulator) {
548 
549  Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
550 
551  if (accumulator)
552  result = rewriter.create<ScalarOp>(loc, accumulator, result);
553  return result;
554 }
555 
556 /// Helper method to lower a `vector.reduction` operation that performs
557 /// a comparison operation like `min`/`max`. `VectorOp` is the LLVM vector
558 /// intrinsic to use and `predicate` is the predicate to use to compare+combine
559 /// the accumulator value if non-null.
560 template <class LLVMRedIntrinOp>
561 static Value createIntegerReductionComparisonOpLowering(
562  ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
563  Value vectorOperand, Value accumulator, LLVM::ICmpPredicate predicate) {
564  Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
565  if (accumulator) {
566  Value cmp =
567  rewriter.create<LLVM::ICmpOp>(loc, predicate, accumulator, result);
568  result = rewriter.create<LLVM::SelectOp>(loc, cmp, accumulator, result);
569  }
570  return result;
571 }
572 
573 namespace {
574 template <typename Source>
575 struct VectorToScalarMapper;
576 template <>
577 struct VectorToScalarMapper<LLVM::vector_reduce_fmaximum> {
578  using Type = LLVM::MaximumOp;
579 };
580 template <>
581 struct VectorToScalarMapper<LLVM::vector_reduce_fminimum> {
582  using Type = LLVM::MinimumOp;
583 };
584 template <>
585 struct VectorToScalarMapper<LLVM::vector_reduce_fmax> {
586  using Type = LLVM::MaxNumOp;
587 };
588 template <>
589 struct VectorToScalarMapper<LLVM::vector_reduce_fmin> {
590  using Type = LLVM::MinNumOp;
591 };
592 } // namespace
593 
594 template <class LLVMRedIntrinOp>
595 static Value createFPReductionComparisonOpLowering(
596  ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
597  Value vectorOperand, Value accumulator, LLVM::FastmathFlagsAttr fmf) {
598  Value result =
599  rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand, fmf);
600 
601  if (accumulator) {
602  result =
603  rewriter.create<typename VectorToScalarMapper<LLVMRedIntrinOp>::Type>(
604  loc, result, accumulator);
605  }
606 
607  return result;
608 }
609 
/// Mask neutral tag types for overloading: each selects the
/// `getMaskNeutralValue` overload for the matching floating-point reduction.
struct MaskNeutralFMaximum {};
struct MaskNeutralFMinimum {};
613 
614 /// Get the mask neutral floating point maximum value
615 static llvm::APFloat
616 getMaskNeutralValue(MaskNeutralFMaximum,
617  const llvm::fltSemantics &floatSemantics) {
618  return llvm::APFloat::getSmallest(floatSemantics, /*Negative=*/true);
619 }
620 /// Get the mask neutral floating point minimum value
621 static llvm::APFloat
622 getMaskNeutralValue(MaskNeutralFMinimum,
623  const llvm::fltSemantics &floatSemantics) {
624  return llvm::APFloat::getLargest(floatSemantics, /*Negative=*/false);
625 }
626 
627 /// Create the mask neutral floating point MLIR vector constant
628 template <typename MaskNeutral>
629 static Value createMaskNeutralValue(ConversionPatternRewriter &rewriter,
630  Location loc, Type llvmType,
631  Type vectorType) {
632  const auto &floatSemantics = cast<FloatType>(llvmType).getFloatSemantics();
633  auto value = getMaskNeutralValue(MaskNeutral{}, floatSemantics);
634  auto denseValue =
635  DenseElementsAttr::get(vectorType.cast<ShapedType>(), value);
636  return rewriter.create<LLVM::ConstantOp>(loc, vectorType, denseValue);
637 }
638 
639 /// Lowers masked `fmaximum` and `fminimum` reductions using the non-masked
640 /// intrinsics. It is a workaround to overcome the lack of masked intrinsics for
641 /// `fmaximum`/`fminimum`.
642 /// More information: https://github.com/llvm/llvm-project/issues/64940
643 template <class LLVMRedIntrinOp, class MaskNeutral>
644 static Value
645 lowerMaskedReductionWithRegular(ConversionPatternRewriter &rewriter,
646  Location loc, Type llvmType,
647  Value vectorOperand, Value accumulator,
648  Value mask, LLVM::FastmathFlagsAttr fmf) {
649  const Value vectorMaskNeutral = createMaskNeutralValue<MaskNeutral>(
650  rewriter, loc, llvmType, vectorOperand.getType());
651  const Value selectedVectorByMask = rewriter.create<LLVM::SelectOp>(
652  loc, mask, vectorOperand, vectorMaskNeutral);
653  return createFPReductionComparisonOpLowering<LLVMRedIntrinOp>(
654  rewriter, loc, llvmType, selectedVectorByMask, accumulator, fmf);
655 }
656 
657 template <class LLVMRedIntrinOp, class ReductionNeutral>
658 static Value
659 lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, Location loc,
660  Type llvmType, Value vectorOperand,
661  Value accumulator, LLVM::FastmathFlagsAttr fmf) {
662  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
663  llvmType, accumulator);
664  return rewriter.create<LLVMRedIntrinOp>(loc, llvmType,
665  /*startValue=*/accumulator,
666  vectorOperand, fmf);
667 }
668 
669 /// Overloaded methods to lower a *predicated* reduction to an llvm instrinsic
670 /// that requires a start value. This start value format spans across fp
671 /// reductions without mask and all the masked reduction intrinsics.
672 template <class LLVMVPRedIntrinOp, class ReductionNeutral>
673 static Value
674 lowerPredicatedReductionWithStartValue(ConversionPatternRewriter &rewriter,
675  Location loc, Type llvmType,
676  Value vectorOperand, Value accumulator) {
677  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
678  llvmType, accumulator);
679  return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
680  /*startValue=*/accumulator,
681  vectorOperand);
682 }
683 
684 template <class LLVMVPRedIntrinOp, class ReductionNeutral>
685 static Value lowerPredicatedReductionWithStartValue(
686  ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
687  Value vectorOperand, Value accumulator, Value mask) {
688  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
689  llvmType, accumulator);
690  Value vectorLength =
691  createVectorLengthValue(rewriter, loc, vectorOperand.getType());
692  return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
693  /*startValue=*/accumulator,
694  vectorOperand, mask, vectorLength);
695 }
696 
697 template <class LLVMIntVPRedIntrinOp, class IntReductionNeutral,
698  class LLVMFPVPRedIntrinOp, class FPReductionNeutral>
699 static Value lowerPredicatedReductionWithStartValue(
700  ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
701  Value vectorOperand, Value accumulator, Value mask) {
702  if (llvmType.isIntOrIndex())
703  return lowerPredicatedReductionWithStartValue<LLVMIntVPRedIntrinOp,
704  IntReductionNeutral>(
705  rewriter, loc, llvmType, vectorOperand, accumulator, mask);
706 
707  // FP dispatch.
708  return lowerPredicatedReductionWithStartValue<LLVMFPVPRedIntrinOp,
709  FPReductionNeutral>(
710  rewriter, loc, llvmType, vectorOperand, accumulator, mask);
711 }
712 
713 /// Conversion pattern for all vector reductions.
714 class VectorReductionOpConversion
715  : public ConvertOpToLLVMPattern<vector::ReductionOp> {
716 public:
717  explicit VectorReductionOpConversion(const LLVMTypeConverter &typeConv,
718  bool reassociateFPRed)
719  : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
720  reassociateFPReductions(reassociateFPRed) {}
721 
723  matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
724  ConversionPatternRewriter &rewriter) const override {
725  auto kind = reductionOp.getKind();
726  Type eltType = reductionOp.getDest().getType();
727  Type llvmType = typeConverter->convertType(eltType);
728  Value operand = adaptor.getVector();
729  Value acc = adaptor.getAcc();
730  Location loc = reductionOp.getLoc();
731 
732  if (eltType.isIntOrIndex()) {
733  // Integer reductions: add/mul/min/max/and/or/xor.
734  Value result;
735  switch (kind) {
736  case vector::CombiningKind::ADD:
737  result =
738  createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_add,
739  LLVM::AddOp>(
740  rewriter, loc, llvmType, operand, acc);
741  break;
742  case vector::CombiningKind::MUL:
743  result =
744  createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_mul,
745  LLVM::MulOp>(
746  rewriter, loc, llvmType, operand, acc);
747  break;
749  result = createIntegerReductionComparisonOpLowering<
750  LLVM::vector_reduce_umin>(rewriter, loc, llvmType, operand, acc,
751  LLVM::ICmpPredicate::ule);
752  break;
753  case vector::CombiningKind::MINSI:
754  result = createIntegerReductionComparisonOpLowering<
755  LLVM::vector_reduce_smin>(rewriter, loc, llvmType, operand, acc,
756  LLVM::ICmpPredicate::sle);
757  break;
758  case vector::CombiningKind::MAXUI:
759  result = createIntegerReductionComparisonOpLowering<
760  LLVM::vector_reduce_umax>(rewriter, loc, llvmType, operand, acc,
761  LLVM::ICmpPredicate::uge);
762  break;
763  case vector::CombiningKind::MAXSI:
764  result = createIntegerReductionComparisonOpLowering<
765  LLVM::vector_reduce_smax>(rewriter, loc, llvmType, operand, acc,
766  LLVM::ICmpPredicate::sge);
767  break;
768  case vector::CombiningKind::AND:
769  result =
770  createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_and,
771  LLVM::AndOp>(
772  rewriter, loc, llvmType, operand, acc);
773  break;
774  case vector::CombiningKind::OR:
775  result =
776  createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_or,
777  LLVM::OrOp>(
778  rewriter, loc, llvmType, operand, acc);
779  break;
780  case vector::CombiningKind::XOR:
781  result =
782  createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_xor,
783  LLVM::XOrOp>(
784  rewriter, loc, llvmType, operand, acc);
785  break;
786  default:
787  return failure();
788  }
789  rewriter.replaceOp(reductionOp, result);
790 
791  return success();
792  }
793 
794  if (!isa<FloatType>(eltType))
795  return failure();
796 
797  arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
798  LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
799  reductionOp.getContext(),
800  convertArithFastMathFlagsToLLVM(fMFAttr.getValue()));
802  reductionOp.getContext(),
803  fmf.getValue() | (reassociateFPReductions ? LLVM::FastmathFlags::reassoc
804  : LLVM::FastmathFlags::none));
805 
806  // Floating-point reductions: add/mul/min/max
807  Value result;
808  if (kind == vector::CombiningKind::ADD) {
809  result = lowerReductionWithStartValue<LLVM::vector_reduce_fadd,
810  ReductionNeutralZero>(
811  rewriter, loc, llvmType, operand, acc, fmf);
812  } else if (kind == vector::CombiningKind::MUL) {
813  result = lowerReductionWithStartValue<LLVM::vector_reduce_fmul,
814  ReductionNeutralFPOne>(
815  rewriter, loc, llvmType, operand, acc, fmf);
816  } else if (kind == vector::CombiningKind::MINIMUMF) {
817  result =
818  createFPReductionComparisonOpLowering<LLVM::vector_reduce_fminimum>(
819  rewriter, loc, llvmType, operand, acc, fmf);
820  } else if (kind == vector::CombiningKind::MAXIMUMF) {
821  result =
822  createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmaximum>(
823  rewriter, loc, llvmType, operand, acc, fmf);
824  } else if (kind == vector::CombiningKind::MINNUMF) {
825  result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmin>(
826  rewriter, loc, llvmType, operand, acc, fmf);
827  } else if (kind == vector::CombiningKind::MAXNUMF) {
828  result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmax>(
829  rewriter, loc, llvmType, operand, acc, fmf);
830  } else
831  return failure();
832 
833  rewriter.replaceOp(reductionOp, result);
834  return success();
835  }
836 
837 private:
838  const bool reassociateFPReductions;
839 };
840 
841 /// Base class to convert a `vector.mask` operation while matching traits
842 /// of the maskable operation nested inside. A `VectorMaskOpConversionBase`
843 /// instance matches against a `vector.mask` operation. The `matchAndRewrite`
844 /// method performs a second match against the maskable operation `MaskedOp`.
845 /// Finally, it invokes the virtual method `matchAndRewriteMaskableOp` to be
846 /// implemented by the concrete conversion classes. This method can match
847 /// against specific traits of the `vector.mask` and the maskable operation. It
848 /// must replace the `vector.mask` operation.
849 template <class MaskedOp>
850 class VectorMaskOpConversionBase
851  : public ConvertOpToLLVMPattern<vector::MaskOp> {
852 public:
854 
856  matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
857  ConversionPatternRewriter &rewriter) const final {
858  // Match against the maskable operation kind.
859  auto maskedOp = llvm::dyn_cast_or_null<MaskedOp>(maskOp.getMaskableOp());
860  if (!maskedOp)
861  return failure();
862  return matchAndRewriteMaskableOp(maskOp, maskedOp, rewriter);
863  }
864 
865 protected:
866  virtual LogicalResult
867  matchAndRewriteMaskableOp(vector::MaskOp maskOp,
868  vector::MaskableOpInterface maskableOp,
869  ConversionPatternRewriter &rewriter) const = 0;
870 };
871 
872 class MaskedReductionOpConversion
873  : public VectorMaskOpConversionBase<vector::ReductionOp> {
874 
875 public:
876  using VectorMaskOpConversionBase<
877  vector::ReductionOp>::VectorMaskOpConversionBase;
878 
879  LogicalResult matchAndRewriteMaskableOp(
880  vector::MaskOp maskOp, MaskableOpInterface maskableOp,
881  ConversionPatternRewriter &rewriter) const override {
882  auto reductionOp = cast<ReductionOp>(maskableOp.getOperation());
883  auto kind = reductionOp.getKind();
884  Type eltType = reductionOp.getDest().getType();
885  Type llvmType = typeConverter->convertType(eltType);
886  Value operand = reductionOp.getVector();
887  Value acc = reductionOp.getAcc();
888  Location loc = reductionOp.getLoc();
889 
890  arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
891  LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
892  reductionOp.getContext(),
893  convertArithFastMathFlagsToLLVM(fMFAttr.getValue()));
894 
895  Value result;
896  switch (kind) {
897  case vector::CombiningKind::ADD:
898  result = lowerPredicatedReductionWithStartValue<
899  LLVM::VPReduceAddOp, ReductionNeutralZero, LLVM::VPReduceFAddOp,
900  ReductionNeutralZero>(rewriter, loc, llvmType, operand, acc,
901  maskOp.getMask());
902  break;
903  case vector::CombiningKind::MUL:
904  result = lowerPredicatedReductionWithStartValue<
905  LLVM::VPReduceMulOp, ReductionNeutralIntOne, LLVM::VPReduceFMulOp,
906  ReductionNeutralFPOne>(rewriter, loc, llvmType, operand, acc,
907  maskOp.getMask());
908  break;
910  result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMinOp,
911  ReductionNeutralUIntMax>(
912  rewriter, loc, llvmType, operand, acc, maskOp.getMask());
913  break;
914  case vector::CombiningKind::MINSI:
915  result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMinOp,
916  ReductionNeutralSIntMax>(
917  rewriter, loc, llvmType, operand, acc, maskOp.getMask());
918  break;
919  case vector::CombiningKind::MAXUI:
920  result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMaxOp,
921  ReductionNeutralUIntMin>(
922  rewriter, loc, llvmType, operand, acc, maskOp.getMask());
923  break;
924  case vector::CombiningKind::MAXSI:
925  result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMaxOp,
926  ReductionNeutralSIntMin>(
927  rewriter, loc, llvmType, operand, acc, maskOp.getMask());
928  break;
929  case vector::CombiningKind::AND:
930  result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceAndOp,
931  ReductionNeutralAllOnes>(
932  rewriter, loc, llvmType, operand, acc, maskOp.getMask());
933  break;
934  case vector::CombiningKind::OR:
935  result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceOrOp,
936  ReductionNeutralZero>(
937  rewriter, loc, llvmType, operand, acc, maskOp.getMask());
938  break;
939  case vector::CombiningKind::XOR:
940  result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceXorOp,
941  ReductionNeutralZero>(
942  rewriter, loc, llvmType, operand, acc, maskOp.getMask());
943  break;
944  case vector::CombiningKind::MINNUMF:
945  result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMinOp,
946  ReductionNeutralFPMax>(
947  rewriter, loc, llvmType, operand, acc, maskOp.getMask());
948  break;
949  case vector::CombiningKind::MAXNUMF:
950  result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMaxOp,
951  ReductionNeutralFPMin>(
952  rewriter, loc, llvmType, operand, acc, maskOp.getMask());
953  break;
954  case CombiningKind::MAXIMUMF:
955  result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fmaximum,
956  MaskNeutralFMaximum>(
957  rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
958  break;
959  case CombiningKind::MINIMUMF:
960  result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fminimum,
961  MaskNeutralFMinimum>(
962  rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
963  break;
964  }
965 
966  // Replace `vector.mask` operation altogether.
967  rewriter.replaceOp(maskOp, result);
968  return success();
969  }
970 };
971 
972 class VectorShuffleOpConversion
973  : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
974 public:
976 
978  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
979  ConversionPatternRewriter &rewriter) const override {
980  auto loc = shuffleOp->getLoc();
981  auto v1Type = shuffleOp.getV1VectorType();
982  auto v2Type = shuffleOp.getV2VectorType();
983  auto vectorType = shuffleOp.getResultVectorType();
984  Type llvmType = typeConverter->convertType(vectorType);
985  auto maskArrayAttr = shuffleOp.getMask();
986 
987  // Bail if result type cannot be lowered.
988  if (!llvmType)
989  return failure();
990 
991  // Get rank and dimension sizes.
992  int64_t rank = vectorType.getRank();
993 #ifndef NDEBUG
994  bool wellFormed0DCase =
995  v1Type.getRank() == 0 && v2Type.getRank() == 0 && rank == 1;
996  bool wellFormedNDCase =
997  v1Type.getRank() == rank && v2Type.getRank() == rank;
998  assert((wellFormed0DCase || wellFormedNDCase) && "op is not well-formed");
999 #endif
1000 
1001  // For rank 0 and 1, where both operands have *exactly* the same vector
1002  // type, there is direct shuffle support in LLVM. Use it!
1003  if (rank <= 1 && v1Type == v2Type) {
1004  Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
1005  loc, adaptor.getV1(), adaptor.getV2(),
1006  LLVM::convertArrayToIndices<int32_t>(maskArrayAttr));
1007  rewriter.replaceOp(shuffleOp, llvmShuffleOp);
1008  return success();
1009  }
1010 
1011  // For all other cases, insert the individual values individually.
1012  int64_t v1Dim = v1Type.getDimSize(0);
1013  Type eltType;
1014  if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(llvmType))
1015  eltType = arrayType.getElementType();
1016  else
1017  eltType = cast<VectorType>(llvmType).getElementType();
1018  Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
1019  int64_t insPos = 0;
1020  for (const auto &en : llvm::enumerate(maskArrayAttr)) {
1021  int64_t extPos = cast<IntegerAttr>(en.value()).getInt();
1022  Value value = adaptor.getV1();
1023  if (extPos >= v1Dim) {
1024  extPos -= v1Dim;
1025  value = adaptor.getV2();
1026  }
1027  Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
1028  eltType, rank, extPos);
1029  insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
1030  llvmType, rank, insPos++);
1031  }
1032  rewriter.replaceOp(shuffleOp, insert);
1033  return success();
1034  }
1035 };
1036 
1037 class VectorExtractElementOpConversion
1038  : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
1039 public:
1040  using ConvertOpToLLVMPattern<
1041  vector::ExtractElementOp>::ConvertOpToLLVMPattern;
1042 
1044  matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
1045  ConversionPatternRewriter &rewriter) const override {
1046  auto vectorType = extractEltOp.getSourceVectorType();
1047  auto llvmType = typeConverter->convertType(vectorType.getElementType());
1048 
1049  // Bail if result type cannot be lowered.
1050  if (!llvmType)
1051  return failure();
1052 
1053  if (vectorType.getRank() == 0) {
1054  Location loc = extractEltOp.getLoc();
1055  auto idxType = rewriter.getIndexType();
1056  auto zero = rewriter.create<LLVM::ConstantOp>(
1057  loc, typeConverter->convertType(idxType),
1058  rewriter.getIntegerAttr(idxType, 0));
1059  rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
1060  extractEltOp, llvmType, adaptor.getVector(), zero);
1061  return success();
1062  }
1063 
1064  rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
1065  extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition());
1066  return success();
1067  }
1068 };
1069 
1070 class VectorExtractOpConversion
1071  : public ConvertOpToLLVMPattern<vector::ExtractOp> {
1072 public:
1074 
1076  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
1077  ConversionPatternRewriter &rewriter) const override {
1078  auto loc = extractOp->getLoc();
1079  auto resultType = extractOp.getResult().getType();
1080  auto llvmResultType = typeConverter->convertType(resultType);
1081  // Bail if result type cannot be lowered.
1082  if (!llvmResultType)
1083  return failure();
1084 
1086  adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
1087 
1088  // Extract entire vector. Should be handled by folder, but just to be safe.
1089  ArrayRef<OpFoldResult> position(positionVec);
1090  if (position.empty()) {
1091  rewriter.replaceOp(extractOp, adaptor.getVector());
1092  return success();
1093  }
1094 
1095  // One-shot extraction of vector from array (only requires extractvalue).
1096  if (isa<VectorType>(resultType)) {
1097  if (extractOp.hasDynamicPosition())
1098  return failure();
1099 
1100  Value extracted = rewriter.create<LLVM::ExtractValueOp>(
1101  loc, adaptor.getVector(), getAsIntegers(position));
1102  rewriter.replaceOp(extractOp, extracted);
1103  return success();
1104  }
1105 
1106  // Potential extraction of 1-D vector from array.
1107  Value extracted = adaptor.getVector();
1108  if (position.size() > 1) {
1109  if (extractOp.hasDynamicPosition())
1110  return failure();
1111 
1112  SmallVector<int64_t> nMinusOnePosition =
1113  getAsIntegers(position.drop_back());
1114  extracted = rewriter.create<LLVM::ExtractValueOp>(loc, extracted,
1115  nMinusOnePosition);
1116  }
1117 
1118  Value lastPosition = getAsLLVMValue(rewriter, loc, position.back());
1119  // Remaining extraction of element from 1-D LLVM vector.
1120  rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(extractOp, extracted,
1121  lastPosition);
1122  return success();
1123  }
1124 };
1125 
1126 /// Conversion pattern that turns a vector.fma on a 1-D vector
1127 /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
1128 /// This does not match vectors of n >= 2 rank.
1129 ///
1130 /// Example:
1131 /// ```
1132 /// vector.fma %a, %a, %a : vector<8xf32>
1133 /// ```
1134 /// is converted to:
1135 /// ```
1136 /// llvm.intr.fmuladd %va, %va, %va:
1137 /// (!llvm."<8 x f32>">, !llvm<"<8 x f32>">, !llvm<"<8 x f32>">)
1138 /// -> !llvm."<8 x f32>">
1139 /// ```
1140 class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
1141 public:
1143 
1145  matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
1146  ConversionPatternRewriter &rewriter) const override {
1147  VectorType vType = fmaOp.getVectorType();
1148  if (vType.getRank() > 1)
1149  return failure();
1150 
1151  rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(
1152  fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
1153  return success();
1154  }
1155 };
1156 
1157 class VectorInsertElementOpConversion
1158  : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
1159 public:
1161 
1163  matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor,
1164  ConversionPatternRewriter &rewriter) const override {
1165  auto vectorType = insertEltOp.getDestVectorType();
1166  auto llvmType = typeConverter->convertType(vectorType);
1167 
1168  // Bail if result type cannot be lowered.
1169  if (!llvmType)
1170  return failure();
1171 
1172  if (vectorType.getRank() == 0) {
1173  Location loc = insertEltOp.getLoc();
1174  auto idxType = rewriter.getIndexType();
1175  auto zero = rewriter.create<LLVM::ConstantOp>(
1176  loc, typeConverter->convertType(idxType),
1177  rewriter.getIntegerAttr(idxType, 0));
1178  rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
1179  insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero);
1180  return success();
1181  }
1182 
1183  rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
1184  insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(),
1185  adaptor.getPosition());
1186  return success();
1187  }
1188 };
1189 
1190 class VectorInsertOpConversion
1191  : public ConvertOpToLLVMPattern<vector::InsertOp> {
1192 public:
1194 
1196  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
1197  ConversionPatternRewriter &rewriter) const override {
1198  auto loc = insertOp->getLoc();
1199  auto sourceType = insertOp.getSourceType();
1200  auto destVectorType = insertOp.getDestVectorType();
1201  auto llvmResultType = typeConverter->convertType(destVectorType);
1202  // Bail if result type cannot be lowered.
1203  if (!llvmResultType)
1204  return failure();
1205 
1207  adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
1208 
1209  // Overwrite entire vector with value. Should be handled by folder, but
1210  // just to be safe.
1211  ArrayRef<OpFoldResult> position(positionVec);
1212  if (position.empty()) {
1213  rewriter.replaceOp(insertOp, adaptor.getSource());
1214  return success();
1215  }
1216 
1217  // One-shot insertion of a vector into an array (only requires insertvalue).
1218  if (isa<VectorType>(sourceType)) {
1219  if (insertOp.hasDynamicPosition())
1220  return failure();
1221 
1222  Value inserted = rewriter.create<LLVM::InsertValueOp>(
1223  loc, adaptor.getDest(), adaptor.getSource(), getAsIntegers(position));
1224  rewriter.replaceOp(insertOp, inserted);
1225  return success();
1226  }
1227 
1228  // Potential extraction of 1-D vector from array.
1229  Value extracted = adaptor.getDest();
1230  auto oneDVectorType = destVectorType;
1231  if (position.size() > 1) {
1232  if (insertOp.hasDynamicPosition())
1233  return failure();
1234 
1235  oneDVectorType = reducedVectorTypeBack(destVectorType);
1236  extracted = rewriter.create<LLVM::ExtractValueOp>(
1237  loc, extracted, getAsIntegers(position.drop_back()));
1238  }
1239 
1240  // Insertion of an element into a 1-D LLVM vector.
1241  Value inserted = rewriter.create<LLVM::InsertElementOp>(
1242  loc, typeConverter->convertType(oneDVectorType), extracted,
1243  adaptor.getSource(), getAsLLVMValue(rewriter, loc, position.back()));
1244 
1245  // Potential insertion of resulting 1-D vector into array.
1246  if (position.size() > 1) {
1247  if (insertOp.hasDynamicPosition())
1248  return failure();
1249 
1250  inserted = rewriter.create<LLVM::InsertValueOp>(
1251  loc, adaptor.getDest(), inserted,
1252  getAsIntegers(position.drop_back()));
1253  }
1254 
1255  rewriter.replaceOp(insertOp, inserted);
1256  return success();
1257  }
1258 };
1259 
1260 /// Lower vector.scalable.insert ops to LLVM vector.insert
1261 struct VectorScalableInsertOpLowering
1262  : public ConvertOpToLLVMPattern<vector::ScalableInsertOp> {
1263  using ConvertOpToLLVMPattern<
1264  vector::ScalableInsertOp>::ConvertOpToLLVMPattern;
1265 
1267  matchAndRewrite(vector::ScalableInsertOp insOp, OpAdaptor adaptor,
1268  ConversionPatternRewriter &rewriter) const override {
1269  rewriter.replaceOpWithNewOp<LLVM::vector_insert>(
1270  insOp, adaptor.getDest(), adaptor.getSource(), adaptor.getPos());
1271  return success();
1272  }
1273 };
1274 
1275 /// Lower vector.scalable.extract ops to LLVM vector.extract
1276 struct VectorScalableExtractOpLowering
1277  : public ConvertOpToLLVMPattern<vector::ScalableExtractOp> {
1278  using ConvertOpToLLVMPattern<
1279  vector::ScalableExtractOp>::ConvertOpToLLVMPattern;
1280 
1282  matchAndRewrite(vector::ScalableExtractOp extOp, OpAdaptor adaptor,
1283  ConversionPatternRewriter &rewriter) const override {
1284  rewriter.replaceOpWithNewOp<LLVM::vector_extract>(
1285  extOp, typeConverter->convertType(extOp.getResultVectorType()),
1286  adaptor.getSource(), adaptor.getPos());
1287  return success();
1288  }
1289 };
1290 
1291 /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
1292 ///
1293 /// Example:
1294 /// ```
1295 /// %d = vector.fma %a, %b, %c : vector<2x4xf32>
1296 /// ```
1297 /// is rewritten into:
1298 /// ```
1299 /// %r = splat %f0: vector<2x4xf32>
1300 /// %va = vector.extractvalue %a[0] : vector<2x4xf32>
1301 /// %vb = vector.extractvalue %b[0] : vector<2x4xf32>
1302 /// %vc = vector.extractvalue %c[0] : vector<2x4xf32>
1303 /// %vd = vector.fma %va, %vb, %vc : vector<4xf32>
1304 /// %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
1305 /// %va2 = vector.extractvalue %a2[1] : vector<2x4xf32>
1306 /// %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32>
1307 /// %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32>
1308 /// %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
1309 /// %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
1310 /// // %r3 holds the final value.
1311 /// ```
1312 class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
1313 public:
1315 
1316  void initialize() {
1317  // This pattern recursively unpacks one dimension at a time. The recursion
1318  // bounded as the rank is strictly decreasing.
1319  setHasBoundedRewriteRecursion();
1320  }
1321 
1322  LogicalResult matchAndRewrite(FMAOp op,
1323  PatternRewriter &rewriter) const override {
1324  auto vType = op.getVectorType();
1325  if (vType.getRank() < 2)
1326  return failure();
1327 
1328  auto loc = op.getLoc();
1329  auto elemType = vType.getElementType();
1330  Value zero = rewriter.create<arith::ConstantOp>(
1331  loc, elemType, rewriter.getZeroAttr(elemType));
1332  Value desc = rewriter.create<vector::SplatOp>(loc, vType, zero);
1333  for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
1334  Value extrLHS = rewriter.create<ExtractOp>(loc, op.getLhs(), i);
1335  Value extrRHS = rewriter.create<ExtractOp>(loc, op.getRhs(), i);
1336  Value extrACC = rewriter.create<ExtractOp>(loc, op.getAcc(), i);
1337  Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
1338  desc = rewriter.create<InsertOp>(loc, fma, desc, i);
1339  }
1340  rewriter.replaceOp(op, desc);
1341  return success();
1342  }
1343 };
1344 
1345 /// Returns the strides if the memory underlying `memRefType` has a contiguous
1346 /// static layout.
1347 static std::optional<SmallVector<int64_t, 4>>
1348 computeContiguousStrides(MemRefType memRefType) {
1349  int64_t offset;
1350  SmallVector<int64_t, 4> strides;
1351  if (failed(getStridesAndOffset(memRefType, strides, offset)))
1352  return std::nullopt;
1353  if (!strides.empty() && strides.back() != 1)
1354  return std::nullopt;
1355  // If no layout or identity layout, this is contiguous by definition.
1356  if (memRefType.getLayout().isIdentity())
1357  return strides;
1358 
1359  // Otherwise, we must determine contiguity form shapes. This can only ever
1360  // work in static cases because MemRefType is underspecified to represent
1361  // contiguous dynamic shapes in other ways than with just empty/identity
1362  // layout.
1363  auto sizes = memRefType.getShape();
1364  for (int index = 0, e = strides.size() - 1; index < e; ++index) {
1365  if (ShapedType::isDynamic(sizes[index + 1]) ||
1366  ShapedType::isDynamic(strides[index]) ||
1367  ShapedType::isDynamic(strides[index + 1]))
1368  return std::nullopt;
1369  if (strides[index] != strides[index + 1] * sizes[index + 1])
1370  return std::nullopt;
1371  }
1372  return strides;
1373 }
1374 
1375 class VectorTypeCastOpConversion
1376  : public ConvertOpToLLVMPattern<vector::TypeCastOp> {
1377 public:
1379 
1381  matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor,
1382  ConversionPatternRewriter &rewriter) const override {
1383  auto loc = castOp->getLoc();
1384  MemRefType sourceMemRefType =
1385  cast<MemRefType>(castOp.getOperand().getType());
1386  MemRefType targetMemRefType = castOp.getType();
1387 
1388  // Only static shape casts supported atm.
1389  if (!sourceMemRefType.hasStaticShape() ||
1390  !targetMemRefType.hasStaticShape())
1391  return failure();
1392 
1393  auto llvmSourceDescriptorTy =
1394  dyn_cast<LLVM::LLVMStructType>(adaptor.getOperands()[0].getType());
1395  if (!llvmSourceDescriptorTy)
1396  return failure();
1397  MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]);
1398 
1399  auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
1400  typeConverter->convertType(targetMemRefType));
1401  if (!llvmTargetDescriptorTy)
1402  return failure();
1403 
1404  // Only contiguous source buffers supported atm.
1405  auto sourceStrides = computeContiguousStrides(sourceMemRefType);
1406  if (!sourceStrides)
1407  return failure();
1408  auto targetStrides = computeContiguousStrides(targetMemRefType);
1409  if (!targetStrides)
1410  return failure();
1411  // Only support static strides for now, regardless of contiguity.
1412  if (llvm::any_of(*targetStrides, ShapedType::isDynamic))
1413  return failure();
1414 
1415  auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
1416 
1417  // Create descriptor.
1418  auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
1419  // Set allocated ptr.
1420  Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
1421  desc.setAllocatedPtr(rewriter, loc, allocated);
1422 
1423  // Set aligned ptr.
1424  Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
1425  desc.setAlignedPtr(rewriter, loc, ptr);
1426  // Fill offset 0.
1427  auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
1428  auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
1429  desc.setOffset(rewriter, loc, zero);
1430 
1431  // Fill size and stride descriptors in memref.
1432  for (const auto &indexedSize :
1433  llvm::enumerate(targetMemRefType.getShape())) {
1434  int64_t index = indexedSize.index();
1435  auto sizeAttr =
1436  rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
1437  auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
1438  desc.setSize(rewriter, loc, index, size);
1439  auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
1440  (*targetStrides)[index]);
1441  auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
1442  desc.setStride(rewriter, loc, index, stride);
1443  }
1444 
1445  rewriter.replaceOp(castOp, {desc});
1446  return success();
1447  }
1448 };
1449 
1450 /// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only).
1451 /// Non-scalable versions of this operation are handled in Vector Transforms.
1452 class VectorCreateMaskOpRewritePattern
1453  : public OpRewritePattern<vector::CreateMaskOp> {
1454 public:
1455  explicit VectorCreateMaskOpRewritePattern(MLIRContext *context,
1456  bool enableIndexOpt)
1457  : OpRewritePattern<vector::CreateMaskOp>(context),
1458  force32BitVectorIndices(enableIndexOpt) {}
1459 
1460  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
1461  PatternRewriter &rewriter) const override {
1462  auto dstType = op.getType();
1463  if (dstType.getRank() != 1 || !cast<VectorType>(dstType).isScalable())
1464  return failure();
1465  IntegerType idxType =
1466  force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
1467  auto loc = op->getLoc();
1468  Value indices = rewriter.create<LLVM::StepVectorOp>(
1469  loc, LLVM::getVectorType(idxType, dstType.getShape()[0],
1470  /*isScalable=*/true));
1471  auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType,
1472  op.getOperand(0));
1473  Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
1474  Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
1475  indices, bounds);
1476  rewriter.replaceOp(op, comp);
1477  return success();
1478  }
1479 
1480 private:
1481  const bool force32BitVectorIndices;
1482 };
1483 
1484 class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
1485 public:
1487 
1488  // Lowering implementation that relies on a small runtime support library,
1489  // which only needs to provide a few printing methods (single value for all
1490  // data types, opening/closing bracket, comma, newline). The lowering splits
1491  // the vector into elementary printing operations. The advantage of this
1492  // approach is that the library can remain unaware of all low-level
1493  // implementation details of vectors while still supporting output of any
1494  // shaped and dimensioned vector.
1495  //
1496  // Note: This lowering only handles scalars, n-D vectors are broken into
1497  // printing scalars in loops in VectorToSCF.
1498  //
1499  // TODO: rely solely on libc in future? something else?
1500  //
1502  matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor,
1503  ConversionPatternRewriter &rewriter) const override {
1504  auto parent = printOp->getParentOfType<ModuleOp>();
1505  if (!parent)
1506  return failure();
1507 
1508  auto loc = printOp->getLoc();
1509 
1510  if (auto value = adaptor.getSource()) {
1511  Type printType = printOp.getPrintType();
1512  if (isa<VectorType>(printType)) {
1513  // Vectors should be broken into elementary print ops in VectorToSCF.
1514  return failure();
1515  }
1516  if (failed(emitScalarPrint(rewriter, parent, loc, printType, value)))
1517  return failure();
1518  }
1519 
1520  auto punct = printOp.getPunctuation();
1521  if (auto stringLiteral = printOp.getStringLiteral()) {
1522  LLVM::createPrintStrCall(rewriter, loc, parent, "vector_print_str",
1523  *stringLiteral, *getTypeConverter(),
1524  /*addNewline=*/false);
1525  } else if (punct != PrintPunctuation::NoPunctuation) {
1526  emitCall(rewriter, printOp->getLoc(), [&] {
1527  switch (punct) {
1528  case PrintPunctuation::Close:
1529  return LLVM::lookupOrCreatePrintCloseFn(parent);
1530  case PrintPunctuation::Open:
1531  return LLVM::lookupOrCreatePrintOpenFn(parent);
1532  case PrintPunctuation::Comma:
1533  return LLVM::lookupOrCreatePrintCommaFn(parent);
1534  case PrintPunctuation::NewLine:
1535  return LLVM::lookupOrCreatePrintNewlineFn(parent);
1536  default:
1537  llvm_unreachable("unexpected punctuation");
1538  }
1539  }());
1540  }
1541 
1542  rewriter.eraseOp(printOp);
1543  return success();
1544  }
1545 
1546 private:
1547  enum class PrintConversion {
1548  // clang-format off
1549  None,
1550  ZeroExt64,
1551  SignExt64,
1552  Bitcast16
1553  // clang-format on
1554  };
1555 
1556  LogicalResult emitScalarPrint(ConversionPatternRewriter &rewriter,
1557  ModuleOp parent, Location loc, Type printType,
1558  Value value) const {
1559  if (typeConverter->convertType(printType) == nullptr)
1560  return failure();
1561 
1562  // Make sure element type has runtime support.
1563  PrintConversion conversion = PrintConversion::None;
1564  Operation *printer;
1565  if (printType.isF32()) {
1566  printer = LLVM::lookupOrCreatePrintF32Fn(parent);
1567  } else if (printType.isF64()) {
1568  printer = LLVM::lookupOrCreatePrintF64Fn(parent);
1569  } else if (printType.isF16()) {
1570  conversion = PrintConversion::Bitcast16; // bits!
1571  printer = LLVM::lookupOrCreatePrintF16Fn(parent);
1572  } else if (printType.isBF16()) {
1573  conversion = PrintConversion::Bitcast16; // bits!
1574  printer = LLVM::lookupOrCreatePrintBF16Fn(parent);
1575  } else if (printType.isIndex()) {
1576  printer = LLVM::lookupOrCreatePrintU64Fn(parent);
1577  } else if (auto intTy = dyn_cast<IntegerType>(printType)) {
1578  // Integers need a zero or sign extension on the operand
1579  // (depending on the source type) as well as a signed or
1580  // unsigned print method. Up to 64-bit is supported.
1581  unsigned width = intTy.getWidth();
1582  if (intTy.isUnsigned()) {
1583  if (width <= 64) {
1584  if (width < 64)
1585  conversion = PrintConversion::ZeroExt64;
1586  printer = LLVM::lookupOrCreatePrintU64Fn(parent);
1587  } else {
1588  return failure();
1589  }
1590  } else {
1591  assert(intTy.isSignless() || intTy.isSigned());
1592  if (width <= 64) {
1593  // Note that we *always* zero extend booleans (1-bit integers),
1594  // so that true/false is printed as 1/0 rather than -1/0.
1595  if (width == 1)
1596  conversion = PrintConversion::ZeroExt64;
1597  else if (width < 64)
1598  conversion = PrintConversion::SignExt64;
1599  printer = LLVM::lookupOrCreatePrintI64Fn(parent);
1600  } else {
1601  return failure();
1602  }
1603  }
1604  } else {
1605  return failure();
1606  }
1607 
1608  switch (conversion) {
1609  case PrintConversion::ZeroExt64:
1610  value = rewriter.create<arith::ExtUIOp>(
1611  loc, IntegerType::get(rewriter.getContext(), 64), value);
1612  break;
1613  case PrintConversion::SignExt64:
1614  value = rewriter.create<arith::ExtSIOp>(
1615  loc, IntegerType::get(rewriter.getContext(), 64), value);
1616  break;
1617  case PrintConversion::Bitcast16:
1618  value = rewriter.create<LLVM::BitcastOp>(
1619  loc, IntegerType::get(rewriter.getContext(), 16), value);
1620  break;
1621  case PrintConversion::None:
1622  break;
1623  }
1624  emitCall(rewriter, loc, printer, value);
1625  return success();
1626  }
1627 
1628  // Helper to emit a call.
1629  static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
1630  Operation *ref, ValueRange params = ValueRange()) {
1631  rewriter.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(ref),
1632  params);
1633  }
1634 };
1635 
1636 /// The Splat operation is lowered to an insertelement + a shufflevector
1637 /// operation. Splat to only 0-d and 1-d vector result types are lowered.
1638 struct VectorSplatOpLowering : public ConvertOpToLLVMPattern<vector::SplatOp> {
1640 
1642  matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
1643  ConversionPatternRewriter &rewriter) const override {
1644  VectorType resultType = cast<VectorType>(splatOp.getType());
1645  if (resultType.getRank() > 1)
1646  return failure();
1647 
1648  // First insert it into an undef vector so we can shuffle it.
1649  auto vectorType = typeConverter->convertType(splatOp.getType());
1650  Value undef = rewriter.create<LLVM::UndefOp>(splatOp.getLoc(), vectorType);
1651  auto zero = rewriter.create<LLVM::ConstantOp>(
1652  splatOp.getLoc(),
1653  typeConverter->convertType(rewriter.getIntegerType(32)),
1654  rewriter.getZeroAttr(rewriter.getIntegerType(32)));
1655 
1656  // For 0-d vector, we simply do `insertelement`.
1657  if (resultType.getRank() == 0) {
1658  rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
1659  splatOp, vectorType, undef, adaptor.getInput(), zero);
1660  return success();
1661  }
1662 
1663  // For 1-d vector, we additionally do a `vectorshuffle`.
1664  auto v = rewriter.create<LLVM::InsertElementOp>(
1665  splatOp.getLoc(), vectorType, undef, adaptor.getInput(), zero);
1666 
1667  int64_t width = cast<VectorType>(splatOp.getType()).getDimSize(0);
1668  SmallVector<int32_t> zeroValues(width, 0);
1669 
1670  // Shuffle the value across the desired number of elements.
1671  rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, undef,
1672  zeroValues);
1673  return success();
1674  }
1675 };
1676 
1677 /// The Splat operation is lowered to an insertelement + a shufflevector
1678 /// operation. Splat to only 2+-d vector result types are lowered by the
1679 /// SplatNdOpLowering, the 1-d case is handled by SplatOpLowering.
1680 struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
1682 
1684  matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
1685  ConversionPatternRewriter &rewriter) const override {
1686  VectorType resultType = splatOp.getType();
1687  if (resultType.getRank() <= 1)
1688  return failure();
1689 
1690  // First insert it into an undef vector so we can shuffle it.
1691  auto loc = splatOp.getLoc();
1692  auto vectorTypeInfo =
1693  LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter());
1694  auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
1695  auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
1696  if (!llvmNDVectorTy || !llvm1DVectorTy)
1697  return failure();
1698 
1699  // Construct returned value.
1700  Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmNDVectorTy);
1701 
1702  // Construct a 1-D vector with the splatted value that we insert in all the
1703  // places within the returned descriptor.
1704  Value vdesc = rewriter.create<LLVM::UndefOp>(loc, llvm1DVectorTy);
1705  auto zero = rewriter.create<LLVM::ConstantOp>(
1706  loc, typeConverter->convertType(rewriter.getIntegerType(32)),
1707  rewriter.getZeroAttr(rewriter.getIntegerType(32)));
1708  Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
1709  adaptor.getInput(), zero);
1710 
1711  // Shuffle the value across the desired number of elements.
1712  int64_t width = resultType.getDimSize(resultType.getRank() - 1);
1713  SmallVector<int32_t> zeroValues(width, 0);
1714  v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroValues);
1715 
1716  // Iterate of linear index, convert to coords space and insert splatted 1-D
1717  // vector in each position.
1718  nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
1719  desc = rewriter.create<LLVM::InsertValueOp>(loc, desc, v, position);
1720  });
1721  rewriter.replaceOp(splatOp, desc);
1722  return success();
1723  }
1724 };
1725 
1726 /// Conversion pattern for a `vector.interleave`.
1727 /// This supports fixed-sized vectors and scalable vectors.
1728 struct VectorInterleaveOpLowering
1729  : public ConvertOpToLLVMPattern<vector::InterleaveOp> {
1731 
1733  matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
1734  ConversionPatternRewriter &rewriter) const override {
1735  VectorType resultType = interleaveOp.getResultVectorType();
1736  // n-D interleaves should have been lowered already.
1737  if (resultType.getRank() != 1)
1738  return rewriter.notifyMatchFailure(interleaveOp,
1739  "InterleaveOp not rank 1");
1740  // If the result is rank 1, then this directly maps to LLVM.
1741  if (resultType.isScalable()) {
1742  rewriter.replaceOpWithNewOp<LLVM::experimental_vector_interleave2>(
1743  interleaveOp, typeConverter->convertType(resultType),
1744  adaptor.getLhs(), adaptor.getRhs());
1745  return success();
1746  }
1747  // Lower fixed-size interleaves to a shufflevector. While the
1748  // vector.interleave2 intrinsic supports fixed and scalable vectors, the
1749  // langref still recommends fixed-vectors use shufflevector, see:
1750  // https://llvm.org/docs/LangRef.html#id876.
1751  int64_t resultVectorSize = resultType.getNumElements();
1752  SmallVector<int32_t> interleaveShuffleMask;
1753  interleaveShuffleMask.reserve(resultVectorSize);
1754  for (int i = 0, end = resultVectorSize / 2; i < end; ++i) {
1755  interleaveShuffleMask.push_back(i);
1756  interleaveShuffleMask.push_back((resultVectorSize / 2) + i);
1757  }
1758  rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(
1759  interleaveOp, adaptor.getLhs(), adaptor.getRhs(),
1760  interleaveShuffleMask);
1761  return success();
1762  }
1763 };
1764 
1765 } // namespace
1766 
1767 /// Populate the given list with patterns that convert from Vector to LLVM.
1769  LLVMTypeConverter &converter, RewritePatternSet &patterns,
1770  bool reassociateFPReductions, bool force32BitVectorIndices) {
1771  MLIRContext *ctx = converter.getDialect()->getContext();
1772  patterns.add<VectorFMAOpNDRewritePattern>(ctx);
1774  patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
1775  patterns.add<VectorCreateMaskOpRewritePattern>(ctx, force32BitVectorIndices);
1776  patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
1777  VectorExtractElementOpConversion, VectorExtractOpConversion,
1778  VectorFMAOp1DConversion, VectorInsertElementOpConversion,
1779  VectorInsertOpConversion, VectorPrintOpConversion,
1780  VectorTypeCastOpConversion, VectorScaleOpConversion,
1781  VectorLoadStoreConversion<vector::LoadOp>,
1782  VectorLoadStoreConversion<vector::MaskedLoadOp>,
1783  VectorLoadStoreConversion<vector::StoreOp>,
1784  VectorLoadStoreConversion<vector::MaskedStoreOp>,
1785  VectorGatherOpConversion, VectorScatterOpConversion,
1786  VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
1787  VectorSplatOpLowering, VectorSplatNdOpLowering,
1788  VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
1789  MaskedReductionOpConversion, VectorInterleaveOpLowering>(
1790  converter);
1791  // Transfer ops with rank > 1 are handled by VectorToSCF.
1792  populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
1793 }
1794 
1796  LLVMTypeConverter &converter, RewritePatternSet &patterns) {
1797  patterns.add<VectorMatmulOpConversion>(converter);
1798  patterns.add<VectorFlatTransposeOpConversion>(converter);
1799 }
static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc, const LLVMTypeConverter &typeConverter, MemRefType memRefType, Value llvmMemref, Value base, Value index, uint64_t vLen)
LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter, MemRefType memrefType, unsigned &align)
static Value extractOne(ConversionPatternRewriter &rewriter, const LLVMTypeConverter &typeConverter, Location loc, Value val, Type llvmType, int64_t rank, int64_t pos)
static Value insertOne(ConversionPatternRewriter &rewriter, const LLVMTypeConverter &typeConverter, Location loc, Value val1, Value val2, Type llvmType, int64_t rank, int64_t pos)
static Value getAsLLVMValue(OpBuilder &builder, Location loc, OpFoldResult foldResult)
Convert foldResult into a Value.
static LogicalResult isMemRefTypeSupported(MemRefType memRefType, const LLVMTypeConverter &converter)
static VectorType reducedVectorTypeBack(VectorType tp)
@ None
#define MINUI(lhs, rhs)
static void printOp(llvm::raw_ostream &os, Operation *op, OpPrintingFlags &flags)
Definition: Unit.cpp:19
Attributes are known-constant values of operations.
Definition: Attributes.h:25
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:216
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:238
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:261
IntegerType getI64Type()
Definition: Builders.cpp:85
IntegerType getI32Type()
Definition: Builders.cpp:83
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:87
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:331
MLIRContext * getContext() const
Definition: Builders.h:55
IndexType getIndexType()
Definition: Builders.cpp:71
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
Definition: Pattern.h:143
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition: Pattern.h:147
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:34
LLVM::LLVMDialect * getDialect() const
Returns the LLVM dialect.
Definition: TypeConverter.h:92
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type type or failure if...
const llvm::DataLayout & getDataLayout() const
Returns the data layout to use during and after conversion.
Utility class to translate MLIR LLVM dialect types to LLVM IR.
Definition: TypeToLLVM.h:39
unsigned getPreferredAlignment(Type type, const llvm::DataLayout &layout)
Returns the preferred alignment for the type given the data layout.
Definition: TypeToLLVM.cpp:196
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Definition: MemRefBuilder.h:33
static MemRefDescriptor undef(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating an undef value of the descriptor type.
LLVM::LLVMPointerType getElementPtrType()
Returns the (LLVM) pointer type this descriptor contains.
Generic implementation of one-to-one conversion from "SourceOp" to "TargetOp" where the latter belong...
Definition: Pattern.h:198
This class helps build Operations.
Definition: Builders.h:209
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
U cast() const
Definition: Types.h:340
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
Definition: Types.cpp:115
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:125
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder, function_ref< void(ArrayRef< int64_t >)> fun)
NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType, const LLVMTypeConverter &converter)
Type getVectorType(Type elementType, unsigned numElements, bool isScalable=false)
Creates an LLVM dialect-compatible vector type with the given element type and length.
Definition: LLVMTypes.cpp:928
LLVM::LLVMFuncOp lookupOrCreatePrintF16Fn(ModuleOp moduleOp)
Type getFixedVectorType(Type elementType, unsigned numElements)
Creates an LLVM dialect-compatible type with the given element type and length.
Definition: LLVMTypes.cpp:955
LLVM::LLVMFuncOp lookupOrCreatePrintU64Fn(ModuleOp moduleOp)
void createPrintStrCall(OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName, StringRef string, const LLVMTypeConverter &typeConverter, bool addNewline=true, std::optional< StringRef > runtimeFunctionName={})
Generate IR that prints the given string to stdout.
LLVM::LLVMFuncOp lookupOrCreatePrintF64Fn(ModuleOp moduleOp)
LLVM::LLVMFuncOp lookupOrCreatePrintBF16Fn(ModuleOp moduleOp)
LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(ModuleOp moduleOp)
llvm::ElementCount getVectorNumElements(Type type)
Returns the element count of any LLVM-compatible vector type.
Definition: LLVMTypes.cpp:901
LLVM::LLVMFuncOp lookupOrCreatePrintI64Fn(ModuleOp moduleOp)
Helper functions to lookup or create the declaration for commonly used external C function calls.
LLVM::FastmathFlags convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF)
Maps arithmetic fastmath enum values to LLVM enum values.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
void populateVectorTransferLoweringPatterns(RewritePatternSet &patterns, std::optional< unsigned > maxTransferRank=std::nullopt, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
SmallVector< int64_t > getAsIntegers(ArrayRef< Value > values)
Returns the integer numbers in values.
Definition: VectorOps.cpp:284
void populateVectorInsertExtractStridedSliceTransforms(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns with the following patterns.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
bool isLastMemrefDimUnitStride(MemRefType type)
Return "true" if the last dimension of the given type has a static unit stride.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
void populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, bool reassociateFPReductions=false, bool force32BitVectorIndices=false)
Collect a set of patterns to convert from the Vector dialect to LLVM.
void populateVectorToLLVMMatrixConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to convert from Vector contractions to LLVM Matrix Intrinsics.
Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, Type targetType, Value value)
Create a cast from an index-like value (index or integer) to another index-like value.
Definition: Utils.cpp:50
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, Builder &b)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358