MLIR  22.0.0git
ConvertVectorToLLVM.cpp
Go to the documentation of this file.
1 //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
27 #include "mlir/IR/BuiltinTypes.h"
28 #include "mlir/IR/TypeUtilities.h"
31 #include "llvm/ADT/APFloat.h"
32 #include "llvm/IR/LLVMContext.h"
33 #include "llvm/Support/Casting.h"
34 
35 #include <optional>
36 
37 using namespace mlir;
38 using namespace mlir::vector;
39 
40 // Helper that picks the proper sequence for inserting.
42  const LLVMTypeConverter &typeConverter, Location loc,
43  Value val1, Value val2, Type llvmType, int64_t rank,
44  int64_t pos) {
45  assert(rank > 0 && "0-D vector corner case should have been handled already");
46  if (rank == 1) {
47  auto idxType = rewriter.getIndexType();
48  auto constant = LLVM::ConstantOp::create(
49  rewriter, loc, typeConverter.convertType(idxType),
50  rewriter.getIntegerAttr(idxType, pos));
51  return LLVM::InsertElementOp::create(rewriter, loc, llvmType, val1, val2,
52  constant);
53  }
54  return LLVM::InsertValueOp::create(rewriter, loc, val1, val2, pos);
55 }
56 
57 // Helper that picks the proper sequence for extracting.
59  const LLVMTypeConverter &typeConverter, Location loc,
60  Value val, Type llvmType, int64_t rank, int64_t pos) {
61  if (rank <= 1) {
62  auto idxType = rewriter.getIndexType();
63  auto constant = LLVM::ConstantOp::create(
64  rewriter, loc, typeConverter.convertType(idxType),
65  rewriter.getIntegerAttr(idxType, pos));
66  return LLVM::ExtractElementOp::create(rewriter, loc, llvmType, val,
67  constant);
68  }
69  return LLVM::ExtractValueOp::create(rewriter, loc, val, pos);
70 }
71 
72 // Helper that returns data layout alignment of a vector.
73 LogicalResult getVectorAlignment(const LLVMTypeConverter &typeConverter,
74  VectorType vectorType, unsigned &align) {
75  Type convertedVectorTy = typeConverter.convertType(vectorType);
76  if (!convertedVectorTy)
77  return failure();
78 
79  llvm::LLVMContext llvmContext;
80  align = LLVM::TypeToLLVMIRTranslator(llvmContext)
81  .getPreferredAlignment(convertedVectorTy,
82  typeConverter.getDataLayout());
83 
84  return success();
85 }
86 
87 // Helper that returns data layout alignment of a memref.
88 LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter,
89  MemRefType memrefType, unsigned &align) {
90  Type elementTy = typeConverter.convertType(memrefType.getElementType());
91  if (!elementTy)
92  return failure();
93 
94  // TODO: this should use the MLIR data layout when it becomes available and
95  // stop depending on translation.
96  llvm::LLVMContext llvmContext;
97  align = LLVM::TypeToLLVMIRTranslator(llvmContext)
98  .getPreferredAlignment(elementTy, typeConverter.getDataLayout());
99  return success();
100 }
101 
102 // Helper to resolve the alignment for vector load/store, gather and scatter
103 // ops. If useVectorAlignment is true, get the preferred alignment for the
104 // vector type in the operation. This option is used for hardware backends with
105 // vectorization. Otherwise, use the preferred alignment of the element type of
106 // the memref. Note that if you choose to use vector alignment, the shape of the
107 // vector type must be resolved before the ConvertVectorToLLVM pass is run.
108 LogicalResult getVectorToLLVMAlignment(const LLVMTypeConverter &typeConverter,
109  VectorType vectorType,
110  MemRefType memrefType, unsigned &align,
111  bool useVectorAlignment) {
112  if (useVectorAlignment) {
113  if (failed(getVectorAlignment(typeConverter, vectorType, align))) {
114  return failure();
115  }
116  } else {
117  if (failed(getMemRefAlignment(typeConverter, memrefType, align))) {
118  return failure();
119  }
120  }
121  return success();
122 }
123 
124 // Check if the last stride is non-unit and has a valid memory space.
125 static LogicalResult isMemRefTypeSupported(MemRefType memRefType,
126  const LLVMTypeConverter &converter) {
127  if (!memRefType.isLastDimUnitStride())
128  return failure();
129  if (failed(converter.getMemRefAddressSpace(memRefType)))
130  return failure();
131  return success();
132 }
133 
134 // Add an index vector component to a base pointer.
136  const LLVMTypeConverter &typeConverter,
137  MemRefType memRefType, Value llvmMemref, Value base,
138  Value index, VectorType vectorType) {
139  assert(succeeded(isMemRefTypeSupported(memRefType, typeConverter)) &&
140  "unsupported memref type");
141  assert(vectorType.getRank() == 1 && "expected a 1-d vector type");
142  auto pType = MemRefDescriptor(llvmMemref).getElementPtrType();
143  auto ptrsType =
144  LLVM::getVectorType(pType, vectorType.getDimSize(0),
145  /*isScalable=*/vectorType.getScalableDims()[0]);
146  return LLVM::GEPOp::create(
147  rewriter, loc, ptrsType,
148  typeConverter.convertType(memRefType.getElementType()), base, index);
149 }
150 
151 /// Convert `foldResult` into a Value. Integer attribute is converted to
152 /// an LLVM constant op.
153 static Value getAsLLVMValue(OpBuilder &builder, Location loc,
154  OpFoldResult foldResult) {
155  if (auto attr = dyn_cast<Attribute>(foldResult)) {
156  auto intAttr = cast<IntegerAttr>(attr);
157  return LLVM::ConstantOp::create(builder, loc, intAttr).getResult();
158  }
159 
160  return cast<Value>(foldResult);
161 }
162 
163 namespace {
164 
165 /// Trivial Vector to LLVM conversions
166 using VectorScaleOpConversion =
168 
169 /// Conversion pattern for a vector.bitcast.
170 class VectorBitCastOpConversion
171  : public ConvertOpToLLVMPattern<vector::BitCastOp> {
172 public:
174 
175  LogicalResult
176  matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor,
177  ConversionPatternRewriter &rewriter) const override {
178  // Only 0-D and 1-D vectors can be lowered to LLVM.
179  VectorType resultTy = bitCastOp.getResultVectorType();
180  if (resultTy.getRank() > 1)
181  return failure();
182  Type newResultTy = typeConverter->convertType(resultTy);
183  rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy,
184  adaptor.getOperands()[0]);
185  return success();
186  }
187 };
188 
189 /// Overloaded utility that replaces a vector.load, vector.store,
190 /// vector.maskedload and vector.maskedstore with their respective LLVM
191 /// couterparts.
192 static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
193  vector::LoadOpAdaptor adaptor,
194  VectorType vectorTy, Value ptr, unsigned align,
195  ConversionPatternRewriter &rewriter) {
196  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, vectorTy, ptr, align,
197  /*volatile_=*/false,
198  loadOp.getNontemporal());
199 }
200 
201 static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
202  vector::MaskedLoadOpAdaptor adaptor,
203  VectorType vectorTy, Value ptr, unsigned align,
204  ConversionPatternRewriter &rewriter) {
205  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
206  loadOp, vectorTy, ptr, adaptor.getMask(), adaptor.getPassThru(), align);
207 }
208 
209 static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
210  vector::StoreOpAdaptor adaptor,
211  VectorType vectorTy, Value ptr, unsigned align,
212  ConversionPatternRewriter &rewriter) {
213  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValueToStore(),
214  ptr, align, /*volatile_=*/false,
215  storeOp.getNontemporal());
216 }
217 
218 static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
219  vector::MaskedStoreOpAdaptor adaptor,
220  VectorType vectorTy, Value ptr, unsigned align,
221  ConversionPatternRewriter &rewriter) {
222  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
223  storeOp, adaptor.getValueToStore(), ptr, adaptor.getMask(), align);
224 }
225 
/// Conversion pattern for a vector.load, vector.store, vector.maskedload, and
/// vector.maskedstore.
/// NOTE(review): the extraction this copy came from dropped original line 235
/// (likely an inherited-constructor `using`) — confirm against upstream.
template <class LoadOrStoreOp>
class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
public:
  /// `useVectorAlign` selects whether alignment is derived from the vector
  /// type (hardware backends) or from the memref element type.
  explicit VectorLoadStoreConversion(const LLVMTypeConverter &typeConv,
                                     bool useVectorAlign)
      : ConvertOpToLLVMPattern<LoadOrStoreOp>(typeConv),
        useVectorAlignment(useVectorAlign) {}

  LogicalResult
  matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
                  typename LoadOrStoreOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only 1-D vectors can be lowered to LLVM.
    VectorType vectorTy = loadOrStoreOp.getVectorType();
    if (vectorTy.getRank() > 1)
      return failure();

    auto loc = loadOrStoreOp->getLoc();
    MemRefType memRefTy = loadOrStoreOp.getMemRefType();

    // Resolve alignment: an explicit `alignment` attribute on the op wins;
    // otherwise fall back to the data-layout preferred alignment.
    unsigned align = loadOrStoreOp.getAlignment().value_or(0);
    if (!align &&
        failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vectorTy,
                                        memRefTy, align, useVectorAlignment)))
      return rewriter.notifyMatchFailure(loadOrStoreOp,
                                         "could not resolve alignment");

    // Resolve address: strided element pointer at the converted indices.
    auto vtype = cast<VectorType>(
        this->typeConverter->convertType(loadOrStoreOp.getVectorType()));
    Value dataPtr = this->getStridedElementPtr(
        rewriter, loc, memRefTy, adaptor.getBase(), adaptor.getIndices());
    // Dispatch to the matching replaceLoadOrStoreOp overload above.
    replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, dataPtr, align,
                         rewriter);
    return success();
  }

private:
  // If true, use the preferred alignment of the vector type.
  // If false, use the preferred alignment of the element type
  // of the memref. This flag is intended for use with hardware
  // backends that require alignment of vector operations.
  const bool useVectorAlignment;
};
274 
/// Conversion pattern for a vector.gather.
/// NOTE(review): the extraction this copy came from dropped original line 283
/// (likely an inherited-constructor `using`) — confirm against upstream.
class VectorGatherOpConversion
    : public ConvertOpToLLVMPattern<vector::GatherOp> {
public:
  /// `useVectorAlign` selects whether alignment is derived from the vector
  /// type (hardware backends) or from the memref element type.
  explicit VectorGatherOpConversion(const LLVMTypeConverter &typeConv,
                                    bool useVectorAlign)
      : ConvertOpToLLVMPattern<vector::GatherOp>(typeConv),
        useVectorAlignment(useVectorAlign) {}

  LogicalResult
  matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = gather->getLoc();
    MemRefType memRefType = dyn_cast<MemRefType>(gather.getBaseType());
    assert(memRefType && "The base should be bufferized");

    // Bail out on non-unit innermost stride or unconvertible address space.
    if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
      return rewriter.notifyMatchFailure(gather, "memref type not supported");

    VectorType vType = gather.getVectorType();
    if (vType.getRank() > 1) {
      return rewriter.notifyMatchFailure(
          gather, "only 1-D vectors can be lowered to LLVM");
    }

    // Resolve alignment.
    unsigned align;
    if (failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
                                        memRefType, align, useVectorAlignment)))
      return rewriter.notifyMatchFailure(gather, "could not resolve alignment");

    // Resolve address: scalar base pointer at the offsets, then a vector of
    // per-lane pointers built from the index vector.
    Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
                                     adaptor.getBase(), adaptor.getOffsets());
    Value base = adaptor.getBase();
    Value ptrs =
        getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
                       base, ptr, adaptor.getIndices(), vType);

    // Replace with the gather intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
        gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
        adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
    return success();
  }

private:
  // If true, use the preferred alignment of the vector type.
  // If false, use the preferred alignment of the element type
  // of the memref. This flag is intended for use with hardware
  // backends that require alignment of vector operations.
  const bool useVectorAlignment;
};
329 
/// Conversion pattern for a vector.scatter.
/// NOTE(review): the extraction this copy came from dropped original line 339
/// (likely an inherited-constructor `using`) — confirm against upstream.
class VectorScatterOpConversion
    : public ConvertOpToLLVMPattern<vector::ScatterOp> {
public:
  /// `useVectorAlign` selects whether alignment is derived from the vector
  /// type (hardware backends) or from the memref element type.
  explicit VectorScatterOpConversion(const LLVMTypeConverter &typeConv,
                                     bool useVectorAlign)
      : ConvertOpToLLVMPattern<vector::ScatterOp>(typeConv),
        useVectorAlignment(useVectorAlign) {}

  LogicalResult
  matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = scatter->getLoc();
    MemRefType memRefType = scatter.getMemRefType();

    // Bail out on non-unit innermost stride or unconvertible address space.
    if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
      return rewriter.notifyMatchFailure(scatter, "memref type not supported");

    VectorType vType = scatter.getVectorType();
    if (vType.getRank() > 1) {
      return rewriter.notifyMatchFailure(
          scatter, "only 1-D vectors can be lowered to LLVM");
    }

    // Resolve alignment.
    unsigned align;
    if (failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
                                        memRefType, align, useVectorAlignment)))
      return rewriter.notifyMatchFailure(scatter,
                                         "could not resolve alignment");

    // Resolve address: scalar base pointer at the offsets, then a vector of
    // per-lane pointers built from the index vector.
    Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
                                     adaptor.getBase(), adaptor.getOffsets());
    Value ptrs =
        getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
                       adaptor.getBase(), ptr, adaptor.getIndices(), vType);

    // Replace with the scatter intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
        scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }

private:
  // If true, use the preferred alignment of the vector type.
  // If false, use the preferred alignment of the element type
  // of the memref. This flag is intended for use with hardware
  // backends that require alignment of vector operations.
  const bool useVectorAlignment;
};
384 
385 /// Conversion pattern for a vector.expandload.
386 class VectorExpandLoadOpConversion
387  : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
388 public:
390 
391  LogicalResult
392  matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor,
393  ConversionPatternRewriter &rewriter) const override {
394  auto loc = expand->getLoc();
395  MemRefType memRefType = expand.getMemRefType();
396 
397  // Resolve address.
398  auto vtype = typeConverter->convertType(expand.getVectorType());
399  Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
400  adaptor.getBase(), adaptor.getIndices());
401 
402  rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
403  expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru());
404  return success();
405  }
406 };
407 
408 /// Conversion pattern for a vector.compressstore.
409 class VectorCompressStoreOpConversion
410  : public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
411 public:
413 
414  LogicalResult
415  matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor,
416  ConversionPatternRewriter &rewriter) const override {
417  auto loc = compress->getLoc();
418  MemRefType memRefType = compress.getMemRefType();
419 
420  // Resolve address.
421  Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
422  adaptor.getBase(), adaptor.getIndices());
423 
424  rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
425  compress, adaptor.getValueToStore(), ptr, adaptor.getMask());
426  return success();
427  }
428 };
429 
/// Empty tag types used to select, via overload resolution, which
/// reduction-neutral element createReductionNeutralValue materializes.
struct ReductionNeutralZero {};
struct ReductionNeutralIntOne {};
struct ReductionNeutralFPOne {};
struct ReductionNeutralAllOnes {};
struct ReductionNeutralSIntMin {};
struct ReductionNeutralUIntMin {};
struct ReductionNeutralSIntMax {};
struct ReductionNeutralUIntMax {};
struct ReductionNeutralFPMin {};
struct ReductionNeutralFPMax {};
441 
442 /// Create the reduction neutral zero value.
443 static Value createReductionNeutralValue(ReductionNeutralZero neutral,
444  ConversionPatternRewriter &rewriter,
445  Location loc, Type llvmType) {
446  return LLVM::ConstantOp::create(rewriter, loc, llvmType,
447  rewriter.getZeroAttr(llvmType));
448 }
449 
450 /// Create the reduction neutral integer one value.
451 static Value createReductionNeutralValue(ReductionNeutralIntOne neutral,
452  ConversionPatternRewriter &rewriter,
453  Location loc, Type llvmType) {
454  return LLVM::ConstantOp::create(rewriter, loc, llvmType,
455  rewriter.getIntegerAttr(llvmType, 1));
456 }
457 
458 /// Create the reduction neutral fp one value.
459 static Value createReductionNeutralValue(ReductionNeutralFPOne neutral,
460  ConversionPatternRewriter &rewriter,
461  Location loc, Type llvmType) {
462  return LLVM::ConstantOp::create(rewriter, loc, llvmType,
463  rewriter.getFloatAttr(llvmType, 1.0));
464 }
465 
466 /// Create the reduction neutral all-ones value.
467 static Value createReductionNeutralValue(ReductionNeutralAllOnes neutral,
468  ConversionPatternRewriter &rewriter,
469  Location loc, Type llvmType) {
470  return LLVM::ConstantOp::create(
471  rewriter, loc, llvmType,
472  rewriter.getIntegerAttr(
473  llvmType, llvm::APInt::getAllOnes(llvmType.getIntOrFloatBitWidth())));
474 }
475 
476 /// Create the reduction neutral signed int minimum value.
477 static Value createReductionNeutralValue(ReductionNeutralSIntMin neutral,
478  ConversionPatternRewriter &rewriter,
479  Location loc, Type llvmType) {
480  return LLVM::ConstantOp::create(
481  rewriter, loc, llvmType,
482  rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMinValue(
483  llvmType.getIntOrFloatBitWidth())));
484 }
485 
486 /// Create the reduction neutral unsigned int minimum value.
487 static Value createReductionNeutralValue(ReductionNeutralUIntMin neutral,
488  ConversionPatternRewriter &rewriter,
489  Location loc, Type llvmType) {
490  return LLVM::ConstantOp::create(
491  rewriter, loc, llvmType,
492  rewriter.getIntegerAttr(llvmType, llvm::APInt::getMinValue(
493  llvmType.getIntOrFloatBitWidth())));
494 }
495 
496 /// Create the reduction neutral signed int maximum value.
497 static Value createReductionNeutralValue(ReductionNeutralSIntMax neutral,
498  ConversionPatternRewriter &rewriter,
499  Location loc, Type llvmType) {
500  return LLVM::ConstantOp::create(
501  rewriter, loc, llvmType,
502  rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMaxValue(
503  llvmType.getIntOrFloatBitWidth())));
504 }
505 
506 /// Create the reduction neutral unsigned int maximum value.
507 static Value createReductionNeutralValue(ReductionNeutralUIntMax neutral,
508  ConversionPatternRewriter &rewriter,
509  Location loc, Type llvmType) {
510  return LLVM::ConstantOp::create(
511  rewriter, loc, llvmType,
512  rewriter.getIntegerAttr(llvmType, llvm::APInt::getMaxValue(
513  llvmType.getIntOrFloatBitWidth())));
514 }
515 
516 /// Create the reduction neutral fp minimum value.
517 static Value createReductionNeutralValue(ReductionNeutralFPMin neutral,
518  ConversionPatternRewriter &rewriter,
519  Location loc, Type llvmType) {
520  auto floatType = cast<FloatType>(llvmType);
521  return LLVM::ConstantOp::create(
522  rewriter, loc, llvmType,
523  rewriter.getFloatAttr(
524  llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
525  /*Negative=*/false)));
526 }
527 
528 /// Create the reduction neutral fp maximum value.
529 static Value createReductionNeutralValue(ReductionNeutralFPMax neutral,
530  ConversionPatternRewriter &rewriter,
531  Location loc, Type llvmType) {
532  auto floatType = cast<FloatType>(llvmType);
533  return LLVM::ConstantOp::create(
534  rewriter, loc, llvmType,
535  rewriter.getFloatAttr(
536  llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
537  /*Negative=*/true)));
538 }
539 
540 /// Returns `accumulator` if it has a valid value. Otherwise, creates and
541 /// returns a new accumulator value using `ReductionNeutral`.
542 template <class ReductionNeutral>
543 static Value getOrCreateAccumulator(ConversionPatternRewriter &rewriter,
544  Location loc, Type llvmType,
545  Value accumulator) {
546  if (accumulator)
547  return accumulator;
548 
549  return createReductionNeutralValue(ReductionNeutral(), rewriter, loc,
550  llvmType);
551 }
552 
553 /// Creates a value with the 1-D vector shape provided in `llvmType`.
554 /// This is used as effective vector length by some intrinsics supporting
555 /// dynamic vector lengths at runtime.
556 static Value createVectorLengthValue(ConversionPatternRewriter &rewriter,
557  Location loc, Type llvmType) {
558  VectorType vType = cast<VectorType>(llvmType);
559  auto vShape = vType.getShape();
560  assert(vShape.size() == 1 && "Unexpected multi-dim vector type");
561 
562  Value baseVecLength = LLVM::ConstantOp::create(
563  rewriter, loc, rewriter.getI32Type(),
564  rewriter.getIntegerAttr(rewriter.getI32Type(), vShape[0]));
565 
566  if (!vType.getScalableDims()[0])
567  return baseVecLength;
568 
569  // For a scalable vector type, create and return `vScale * baseVecLength`.
570  Value vScale = vector::VectorScaleOp::create(rewriter, loc);
571  vScale =
572  arith::IndexCastOp::create(rewriter, loc, rewriter.getI32Type(), vScale);
573  Value scalableVecLength =
574  arith::MulIOp::create(rewriter, loc, baseVecLength, vScale);
575  return scalableVecLength;
576 }
577 
578 /// Helper method to lower a `vector.reduction` op that performs an arithmetic
579 /// operation like add,mul, etc.. `VectorOp` is the LLVM vector intrinsic to use
580 /// and `ScalarOp` is the scalar operation used to add the accumulation value if
581 /// non-null.
582 template <class LLVMRedIntrinOp, class ScalarOp>
583 static Value createIntegerReductionArithmeticOpLowering(
584  ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
585  Value vectorOperand, Value accumulator) {
586 
587  Value result =
588  LLVMRedIntrinOp::create(rewriter, loc, llvmType, vectorOperand);
589 
590  if (accumulator)
591  result = ScalarOp::create(rewriter, loc, accumulator, result);
592  return result;
593 }
594 
595 /// Helper method to lower a `vector.reduction` operation that performs
596 /// a comparison operation like `min`/`max`. `VectorOp` is the LLVM vector
597 /// intrinsic to use and `predicate` is the predicate to use to compare+combine
598 /// the accumulator value if non-null.
599 template <class LLVMRedIntrinOp>
600 static Value createIntegerReductionComparisonOpLowering(
601  ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
602  Value vectorOperand, Value accumulator, LLVM::ICmpPredicate predicate) {
603  Value result =
604  LLVMRedIntrinOp::create(rewriter, loc, llvmType, vectorOperand);
605  if (accumulator) {
606  Value cmp =
607  LLVM::ICmpOp::create(rewriter, loc, predicate, accumulator, result);
608  result = LLVM::SelectOp::create(rewriter, loc, cmp, accumulator, result);
609  }
610  return result;
611 }
612 
namespace {
/// Maps an LLVM vector-reduction intrinsic to the scalar LLVM op used to
/// combine its result with an optional accumulator value.
template <typename Source>
struct VectorToScalarMapper;
template <>
struct VectorToScalarMapper<LLVM::vector_reduce_fmaximum> {
  using Type = LLVM::MaximumOp; // llvm.maximum (NaN-propagating)
};
template <>
struct VectorToScalarMapper<LLVM::vector_reduce_fminimum> {
  using Type = LLVM::MinimumOp; // llvm.minimum (NaN-propagating)
};
template <>
struct VectorToScalarMapper<LLVM::vector_reduce_fmax> {
  using Type = LLVM::MaxNumOp; // llvm.maxnum
};
template <>
struct VectorToScalarMapper<LLVM::vector_reduce_fmin> {
  using Type = LLVM::MinNumOp; // llvm.minnum
};
} // namespace
633 
634 template <class LLVMRedIntrinOp>
635 static Value createFPReductionComparisonOpLowering(
636  ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
637  Value vectorOperand, Value accumulator, LLVM::FastmathFlagsAttr fmf) {
638  Value result =
639  LLVMRedIntrinOp::create(rewriter, loc, llvmType, vectorOperand, fmf);
640 
641  if (accumulator) {
642  result = VectorToScalarMapper<LLVMRedIntrinOp>::Type::create(
643  rewriter, loc, result, accumulator);
644  }
645 
646  return result;
647 }
648 
649 /// Reduction neutral classes for overloading
650 class MaskNeutralFMaximum {};
651 class MaskNeutralFMinimum {};
652 
653 /// Get the mask neutral floating point maximum value
654 static llvm::APFloat
655 getMaskNeutralValue(MaskNeutralFMaximum,
656  const llvm::fltSemantics &floatSemantics) {
657  return llvm::APFloat::getSmallest(floatSemantics, /*Negative=*/true);
658 }
659 /// Get the mask neutral floating point minimum value
660 static llvm::APFloat
661 getMaskNeutralValue(MaskNeutralFMinimum,
662  const llvm::fltSemantics &floatSemantics) {
663  return llvm::APFloat::getLargest(floatSemantics, /*Negative=*/false);
664 }
665 
666 /// Create the mask neutral floating point MLIR vector constant
667 template <typename MaskNeutral>
668 static Value createMaskNeutralValue(ConversionPatternRewriter &rewriter,
669  Location loc, Type llvmType,
670  Type vectorType) {
671  const auto &floatSemantics = cast<FloatType>(llvmType).getFloatSemantics();
672  auto value = getMaskNeutralValue(MaskNeutral{}, floatSemantics);
673  auto denseValue = DenseElementsAttr::get(cast<ShapedType>(vectorType), value);
674  return LLVM::ConstantOp::create(rewriter, loc, vectorType, denseValue);
675 }
676 
677 /// Lowers masked `fmaximum` and `fminimum` reductions using the non-masked
678 /// intrinsics. It is a workaround to overcome the lack of masked intrinsics for
679 /// `fmaximum`/`fminimum`.
680 /// More information: https://github.com/llvm/llvm-project/issues/64940
681 template <class LLVMRedIntrinOp, class MaskNeutral>
682 static Value
683 lowerMaskedReductionWithRegular(ConversionPatternRewriter &rewriter,
684  Location loc, Type llvmType,
685  Value vectorOperand, Value accumulator,
686  Value mask, LLVM::FastmathFlagsAttr fmf) {
687  const Value vectorMaskNeutral = createMaskNeutralValue<MaskNeutral>(
688  rewriter, loc, llvmType, vectorOperand.getType());
689  const Value selectedVectorByMask = LLVM::SelectOp::create(
690  rewriter, loc, mask, vectorOperand, vectorMaskNeutral);
691  return createFPReductionComparisonOpLowering<LLVMRedIntrinOp>(
692  rewriter, loc, llvmType, selectedVectorByMask, accumulator, fmf);
693 }
694 
695 template <class LLVMRedIntrinOp, class ReductionNeutral>
696 static Value
697 lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, Location loc,
698  Type llvmType, Value vectorOperand,
699  Value accumulator, LLVM::FastmathFlagsAttr fmf) {
700  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
701  llvmType, accumulator);
702  return LLVMRedIntrinOp::create(rewriter, loc, llvmType,
703  /*startValue=*/accumulator, vectorOperand,
704  fmf);
705 }
706 
707 /// Overloaded methods to lower a *predicated* reduction to an llvm intrinsic
708 /// that requires a start value. This start value format spans across fp
709 /// reductions without mask and all the masked reduction intrinsics.
710 template <class LLVMVPRedIntrinOp, class ReductionNeutral>
711 static Value
712 lowerPredicatedReductionWithStartValue(ConversionPatternRewriter &rewriter,
713  Location loc, Type llvmType,
714  Value vectorOperand, Value accumulator) {
715  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
716  llvmType, accumulator);
717  return LLVMVPRedIntrinOp::create(rewriter, loc, llvmType,
718  /*startValue=*/accumulator, vectorOperand);
719 }
720 
721 template <class LLVMVPRedIntrinOp, class ReductionNeutral>
722 static Value lowerPredicatedReductionWithStartValue(
723  ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
724  Value vectorOperand, Value accumulator, Value mask) {
725  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
726  llvmType, accumulator);
727  Value vectorLength =
728  createVectorLengthValue(rewriter, loc, vectorOperand.getType());
729  return LLVMVPRedIntrinOp::create(rewriter, loc, llvmType,
730  /*startValue=*/accumulator, vectorOperand,
731  mask, vectorLength);
732 }
733 
734 template <class LLVMIntVPRedIntrinOp, class IntReductionNeutral,
735  class LLVMFPVPRedIntrinOp, class FPReductionNeutral>
736 static Value lowerPredicatedReductionWithStartValue(
737  ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
738  Value vectorOperand, Value accumulator, Value mask) {
739  if (llvmType.isIntOrIndex())
740  return lowerPredicatedReductionWithStartValue<LLVMIntVPRedIntrinOp,
741  IntReductionNeutral>(
742  rewriter, loc, llvmType, vectorOperand, accumulator, mask);
743 
744  // FP dispatch.
745  return lowerPredicatedReductionWithStartValue<LLVMFPVPRedIntrinOp,
746  FPReductionNeutral>(
747  rewriter, loc, llvmType, vectorOperand, accumulator, mask);
748 }
749 
750 /// Conversion pattern for all vector reductions.
751 class VectorReductionOpConversion
752  : public ConvertOpToLLVMPattern<vector::ReductionOp> {
753 public:
754  explicit VectorReductionOpConversion(const LLVMTypeConverter &typeConv,
755  bool reassociateFPRed)
756  : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
757  reassociateFPReductions(reassociateFPRed) {}
758 
759  LogicalResult
760  matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
761  ConversionPatternRewriter &rewriter) const override {
762  auto kind = reductionOp.getKind();
763  Type eltType = reductionOp.getDest().getType();
764  Type llvmType = typeConverter->convertType(eltType);
765  Value operand = adaptor.getVector();
766  Value acc = adaptor.getAcc();
767  Location loc = reductionOp.getLoc();
768 
769  if (eltType.isIntOrIndex()) {
770  // Integer reductions: add/mul/min/max/and/or/xor.
771  Value result;
772  switch (kind) {
773  case vector::CombiningKind::ADD:
774  result =
775  createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_add,
776  LLVM::AddOp>(
777  rewriter, loc, llvmType, operand, acc);
778  break;
779  case vector::CombiningKind::MUL:
780  result =
781  createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_mul,
782  LLVM::MulOp>(
783  rewriter, loc, llvmType, operand, acc);
784  break;
786  result = createIntegerReductionComparisonOpLowering<
787  LLVM::vector_reduce_umin>(rewriter, loc, llvmType, operand, acc,
788  LLVM::ICmpPredicate::ule);
789  break;
790  case vector::CombiningKind::MINSI:
791  result = createIntegerReductionComparisonOpLowering<
792  LLVM::vector_reduce_smin>(rewriter, loc, llvmType, operand, acc,
793  LLVM::ICmpPredicate::sle);
794  break;
795  case vector::CombiningKind::MAXUI:
796  result = createIntegerReductionComparisonOpLowering<
797  LLVM::vector_reduce_umax>(rewriter, loc, llvmType, operand, acc,
798  LLVM::ICmpPredicate::uge);
799  break;
800  case vector::CombiningKind::MAXSI:
801  result = createIntegerReductionComparisonOpLowering<
802  LLVM::vector_reduce_smax>(rewriter, loc, llvmType, operand, acc,
803  LLVM::ICmpPredicate::sge);
804  break;
805  case vector::CombiningKind::AND:
806  result =
807  createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_and,
808  LLVM::AndOp>(
809  rewriter, loc, llvmType, operand, acc);
810  break;
811  case vector::CombiningKind::OR:
812  result =
813  createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_or,
814  LLVM::OrOp>(
815  rewriter, loc, llvmType, operand, acc);
816  break;
817  case vector::CombiningKind::XOR:
818  result =
819  createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_xor,
820  LLVM::XOrOp>(
821  rewriter, loc, llvmType, operand, acc);
822  break;
823  default:
824  return failure();
825  }
826  rewriter.replaceOp(reductionOp, result);
827 
828  return success();
829  }
830 
831  if (!isa<FloatType>(eltType))
832  return failure();
833 
834  arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
835  LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
836  reductionOp.getContext(),
837  convertArithFastMathFlagsToLLVM(fMFAttr.getValue()));
839  reductionOp.getContext(),
840  fmf.getValue() | (reassociateFPReductions ? LLVM::FastmathFlags::reassoc
841  : LLVM::FastmathFlags::none));
842 
843  // Floating-point reductions: add/mul/min/max
844  Value result;
845  if (kind == vector::CombiningKind::ADD) {
846  result = lowerReductionWithStartValue<LLVM::vector_reduce_fadd,
847  ReductionNeutralZero>(
848  rewriter, loc, llvmType, operand, acc, fmf);
849  } else if (kind == vector::CombiningKind::MUL) {
850  result = lowerReductionWithStartValue<LLVM::vector_reduce_fmul,
851  ReductionNeutralFPOne>(
852  rewriter, loc, llvmType, operand, acc, fmf);
853  } else if (kind == vector::CombiningKind::MINIMUMF) {
854  result =
855  createFPReductionComparisonOpLowering<LLVM::vector_reduce_fminimum>(
856  rewriter, loc, llvmType, operand, acc, fmf);
857  } else if (kind == vector::CombiningKind::MAXIMUMF) {
858  result =
859  createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmaximum>(
860  rewriter, loc, llvmType, operand, acc, fmf);
861  } else if (kind == vector::CombiningKind::MINNUMF) {
862  result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmin>(
863  rewriter, loc, llvmType, operand, acc, fmf);
864  } else if (kind == vector::CombiningKind::MAXNUMF) {
865  result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmax>(
866  rewriter, loc, llvmType, operand, acc, fmf);
867  } else {
868  return failure();
869  }
870 
871  rewriter.replaceOp(reductionOp, result);
872  return success();
873  }
874 
875 private:
876  const bool reassociateFPReductions;
877 };
878 
879 /// Base class to convert a `vector.mask` operation while matching traits
880 /// of the maskable operation nested inside. A `VectorMaskOpConversionBase`
881 /// instance matches against a `vector.mask` operation. The `matchAndRewrite`
882 /// method performs a second match against the maskable operation `MaskedOp`.
883 /// Finally, it invokes the virtual method `matchAndRewriteMaskableOp` to be
884 /// implemented by the concrete conversion classes. This method can match
885 /// against specific traits of the `vector.mask` and the maskable operation. It
886 /// must replace the `vector.mask` operation.
887 template <class MaskedOp>
888 class VectorMaskOpConversionBase
889  : public ConvertOpToLLVMPattern<vector::MaskOp> {
890 public:
892 
893  LogicalResult
894  matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
895  ConversionPatternRewriter &rewriter) const final {
896  // Match against the maskable operation kind.
897  auto maskedOp = llvm::dyn_cast_or_null<MaskedOp>(maskOp.getMaskableOp());
898  if (!maskedOp)
899  return failure();
900  return matchAndRewriteMaskableOp(maskOp, maskedOp, rewriter);
901  }
902 
903 protected:
904  virtual LogicalResult
905  matchAndRewriteMaskableOp(vector::MaskOp maskOp,
906  vector::MaskableOpInterface maskableOp,
907  ConversionPatternRewriter &rewriter) const = 0;
908 };
909 
910 class MaskedReductionOpConversion
911  : public VectorMaskOpConversionBase<vector::ReductionOp> {
912 
913 public:
914  using VectorMaskOpConversionBase<
915  vector::ReductionOp>::VectorMaskOpConversionBase;
916 
917  LogicalResult matchAndRewriteMaskableOp(
918  vector::MaskOp maskOp, MaskableOpInterface maskableOp,
919  ConversionPatternRewriter &rewriter) const override {
920  auto reductionOp = cast<ReductionOp>(maskableOp.getOperation());
921  auto kind = reductionOp.getKind();
922  Type eltType = reductionOp.getDest().getType();
923  Type llvmType = typeConverter->convertType(eltType);
924  Value operand = reductionOp.getVector();
925  Value acc = reductionOp.getAcc();
926  Location loc = reductionOp.getLoc();
927 
928  arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
929  LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
930  reductionOp.getContext(),
931  convertArithFastMathFlagsToLLVM(fMFAttr.getValue()));
932 
933  Value result;
934  switch (kind) {
935  case vector::CombiningKind::ADD:
936  result = lowerPredicatedReductionWithStartValue<
937  LLVM::VPReduceAddOp, ReductionNeutralZero, LLVM::VPReduceFAddOp,
938  ReductionNeutralZero>(rewriter, loc, llvmType, operand, acc,
939  maskOp.getMask());
940  break;
941  case vector::CombiningKind::MUL:
942  result = lowerPredicatedReductionWithStartValue<
943  LLVM::VPReduceMulOp, ReductionNeutralIntOne, LLVM::VPReduceFMulOp,
944  ReductionNeutralFPOne>(rewriter, loc, llvmType, operand, acc,
945  maskOp.getMask());
946  break;
948  result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMinOp,
949  ReductionNeutralUIntMax>(
950  rewriter, loc, llvmType, operand, acc, maskOp.getMask());
951  break;
952  case vector::CombiningKind::MINSI:
953  result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMinOp,
954  ReductionNeutralSIntMax>(
955  rewriter, loc, llvmType, operand, acc, maskOp.getMask());
956  break;
957  case vector::CombiningKind::MAXUI:
958  result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMaxOp,
959  ReductionNeutralUIntMin>(
960  rewriter, loc, llvmType, operand, acc, maskOp.getMask());
961  break;
962  case vector::CombiningKind::MAXSI:
963  result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMaxOp,
964  ReductionNeutralSIntMin>(
965  rewriter, loc, llvmType, operand, acc, maskOp.getMask());
966  break;
967  case vector::CombiningKind::AND:
968  result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceAndOp,
969  ReductionNeutralAllOnes>(
970  rewriter, loc, llvmType, operand, acc, maskOp.getMask());
971  break;
972  case vector::CombiningKind::OR:
973  result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceOrOp,
974  ReductionNeutralZero>(
975  rewriter, loc, llvmType, operand, acc, maskOp.getMask());
976  break;
977  case vector::CombiningKind::XOR:
978  result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceXorOp,
979  ReductionNeutralZero>(
980  rewriter, loc, llvmType, operand, acc, maskOp.getMask());
981  break;
982  case vector::CombiningKind::MINNUMF:
983  result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMinOp,
984  ReductionNeutralFPMax>(
985  rewriter, loc, llvmType, operand, acc, maskOp.getMask());
986  break;
987  case vector::CombiningKind::MAXNUMF:
988  result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMaxOp,
989  ReductionNeutralFPMin>(
990  rewriter, loc, llvmType, operand, acc, maskOp.getMask());
991  break;
992  case CombiningKind::MAXIMUMF:
993  result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fmaximum,
994  MaskNeutralFMaximum>(
995  rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
996  break;
997  case CombiningKind::MINIMUMF:
998  result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fminimum,
999  MaskNeutralFMinimum>(
1000  rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
1001  break;
1002  }
1003 
1004  // Replace `vector.mask` operation altogether.
1005  rewriter.replaceOp(maskOp, result);
1006  return success();
1007  }
1008 };
1009 
1010 class VectorShuffleOpConversion
1011  : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
1012 public:
1014 
1015  LogicalResult
1016  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
1017  ConversionPatternRewriter &rewriter) const override {
1018  auto loc = shuffleOp->getLoc();
1019  auto v1Type = shuffleOp.getV1VectorType();
1020  auto v2Type = shuffleOp.getV2VectorType();
1021  auto vectorType = shuffleOp.getResultVectorType();
1022  Type llvmType = typeConverter->convertType(vectorType);
1023  ArrayRef<int64_t> mask = shuffleOp.getMask();
1024 
1025  // Bail if result type cannot be lowered.
1026  if (!llvmType)
1027  return failure();
1028 
1029  // Get rank and dimension sizes.
1030  int64_t rank = vectorType.getRank();
1031 #ifndef NDEBUG
1032  bool wellFormed0DCase =
1033  v1Type.getRank() == 0 && v2Type.getRank() == 0 && rank == 1;
1034  bool wellFormedNDCase =
1035  v1Type.getRank() == rank && v2Type.getRank() == rank;
1036  assert((wellFormed0DCase || wellFormedNDCase) && "op is not well-formed");
1037 #endif
1038 
1039  // For rank 0 and 1, where both operands have *exactly* the same vector
1040  // type, there is direct shuffle support in LLVM. Use it!
1041  if (rank <= 1 && v1Type == v2Type) {
1042  Value llvmShuffleOp = LLVM::ShuffleVectorOp::create(
1043  rewriter, loc, adaptor.getV1(), adaptor.getV2(),
1044  llvm::to_vector_of<int32_t>(mask));
1045  rewriter.replaceOp(shuffleOp, llvmShuffleOp);
1046  return success();
1047  }
1048 
1049  // For all other cases, insert the individual values individually.
1050  int64_t v1Dim = v1Type.getDimSize(0);
1051  Type eltType;
1052  if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(llvmType))
1053  eltType = arrayType.getElementType();
1054  else
1055  eltType = cast<VectorType>(llvmType).getElementType();
1056  Value insert = LLVM::PoisonOp::create(rewriter, loc, llvmType);
1057  int64_t insPos = 0;
1058  for (int64_t extPos : mask) {
1059  Value value = adaptor.getV1();
1060  if (extPos >= v1Dim) {
1061  extPos -= v1Dim;
1062  value = adaptor.getV2();
1063  }
1064  Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
1065  eltType, rank, extPos);
1066  insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
1067  llvmType, rank, insPos++);
1068  }
1069  rewriter.replaceOp(shuffleOp, insert);
1070  return success();
1071  }
1072 };
1073 
1074 class VectorExtractOpConversion
1075  : public ConvertOpToLLVMPattern<vector::ExtractOp> {
1076 public:
1078 
1079  LogicalResult
1080  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
1081  ConversionPatternRewriter &rewriter) const override {
1082  auto loc = extractOp->getLoc();
1083  auto resultType = extractOp.getResult().getType();
1084  auto llvmResultType = typeConverter->convertType(resultType);
1085  // Bail if result type cannot be lowered.
1086  if (!llvmResultType)
1087  return failure();
1088 
1090  adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
1091 
1092  // The Vector -> LLVM lowering models N-D vectors as nested aggregates of
1093  // 1-d vectors. This nesting is modeled using arrays. We do this conversion
1094  // from a N-d vector extract to a nested aggregate vector extract in two
1095  // steps:
1096  // - Extract a member from the nested aggregate. The result can be
1097  // a lower rank nested aggregate or a vector (1-D). This is done using
1098  // `llvm.extractvalue`.
1099  // - Extract a scalar out of the vector if needed. This is done using
1100  // `llvm.extractelement`.
1101 
1102  // Determine if we need to extract a member out of the aggregate. We
1103  // always need to extract a member if the input rank >= 2.
1104  bool extractsAggregate = extractOp.getSourceVectorType().getRank() >= 2;
1105  // Determine if we need to extract a scalar as the result. We extract
1106  // a scalar if the extract is full rank, i.e., the number of indices is
1107  // equal to source vector rank.
1108  bool extractsScalar = static_cast<int64_t>(positionVec.size()) ==
1109  extractOp.getSourceVectorType().getRank();
1110 
1111  // Since the LLVM type converter converts 0-d vectors to 1-d vectors, we
1112  // need to add a position for this change.
1113  if (extractOp.getSourceVectorType().getRank() == 0) {
1114  Type idxType = typeConverter->convertType(rewriter.getIndexType());
1115  positionVec.push_back(rewriter.getZeroAttr(idxType));
1116  }
1117 
1118  Value extracted = adaptor.getVector();
1119  if (extractsAggregate) {
1120  ArrayRef<OpFoldResult> position(positionVec);
1121  if (extractsScalar) {
1122  // If we are extracting a scalar from the extracted member, we drop
1123  // the last index, which will be used to extract the scalar out of the
1124  // vector.
1125  position = position.drop_back();
1126  }
1127  // llvm.extractvalue does not support dynamic dimensions.
1128  if (!llvm::all_of(position, llvm::IsaPred<Attribute>)) {
1129  return failure();
1130  }
1131  extracted = LLVM::ExtractValueOp::create(rewriter, loc, extracted,
1132  getAsIntegers(position));
1133  }
1134 
1135  if (extractsScalar) {
1136  extracted = LLVM::ExtractElementOp::create(
1137  rewriter, loc, extracted,
1138  getAsLLVMValue(rewriter, loc, positionVec.back()));
1139  }
1140 
1141  rewriter.replaceOp(extractOp, extracted);
1142  return success();
1143  }
1144 };
1145 
1146 /// Conversion pattern that turns a vector.fma on a 1-D vector
1147 /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
1148 /// This does not match vectors of n >= 2 rank.
1149 ///
1150 /// Example:
1151 /// ```
1152 /// vector.fma %a, %a, %a : vector<8xf32>
1153 /// ```
1154 /// is converted to:
1155 /// ```
1156 /// llvm.intr.fmuladd %va, %va, %va:
1157 /// (!llvm."<8 x f32>">, !llvm<"<8 x f32>">, !llvm<"<8 x f32>">)
1158 /// -> !llvm."<8 x f32>">
1159 /// ```
1160 class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
1161 public:
1163 
1164  LogicalResult
1165  matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
1166  ConversionPatternRewriter &rewriter) const override {
1167  VectorType vType = fmaOp.getVectorType();
1168  if (vType.getRank() > 1)
1169  return failure();
1170 
1171  rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(
1172  fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
1173  return success();
1174  }
1175 };
1176 
1177 class VectorInsertOpConversion
1178  : public ConvertOpToLLVMPattern<vector::InsertOp> {
1179 public:
1181 
1182  LogicalResult
1183  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
1184  ConversionPatternRewriter &rewriter) const override {
1185  auto loc = insertOp->getLoc();
1186  auto destVectorType = insertOp.getDestVectorType();
1187  auto llvmResultType = typeConverter->convertType(destVectorType);
1188  // Bail if result type cannot be lowered.
1189  if (!llvmResultType)
1190  return failure();
1191 
1193  adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
1194 
1195  // The logic in this pattern mirrors VectorExtractOpConversion. Refer to
1196  // its explanatory comment about how N-D vectors are converted as nested
1197  // aggregates (llvm.array's) of 1D vectors.
1198  //
1199  // The innermost dimension of the destination vector, when converted to a
1200  // nested aggregate form, will always be a 1D vector.
1201  //
1202  // * If the insertion is happening into the innermost dimension of the
1203  // destination vector:
1204  // - If the destination is a nested aggregate, extract a 1D vector out of
1205  // the aggregate. This can be done using llvm.extractvalue. The
1206  // destination is now guaranteed to be a 1D vector, to which we are
1207  // inserting.
1208  // - Do the insertion into the 1D destination vector, and make the result
1209  // the new source nested aggregate. This can be done using
1210  // llvm.insertelement.
1211  // * Insert the source nested aggregate into the destination nested
1212  // aggregate.
1213 
1214  // Determine if we need to extract/insert a 1D vector out of the aggregate.
1215  bool isNestedAggregate = isa<LLVM::LLVMArrayType>(llvmResultType);
1216  // Determine if we need to insert a scalar into the 1D vector.
1217  bool insertIntoInnermostDim =
1218  static_cast<int64_t>(positionVec.size()) == destVectorType.getRank();
1219 
1220  ArrayRef<OpFoldResult> positionOf1DVectorWithinAggregate(
1221  positionVec.begin(),
1222  insertIntoInnermostDim ? positionVec.size() - 1 : positionVec.size());
1223  OpFoldResult positionOfScalarWithin1DVector;
1224  if (destVectorType.getRank() == 0) {
1225  // Since the LLVM type converter converts 0D vectors to 1D vectors, we
1226  // need to create a 0 here as the position into the 1D vector.
1227  Type idxType = typeConverter->convertType(rewriter.getIndexType());
1228  positionOfScalarWithin1DVector = rewriter.getZeroAttr(idxType);
1229  } else if (insertIntoInnermostDim) {
1230  positionOfScalarWithin1DVector = positionVec.back();
1231  }
1232 
1233  // We are going to mutate this 1D vector until it is either the final
1234  // result (in the non-aggregate case) or the value that needs to be
1235  // inserted into the aggregate result.
1236  Value sourceAggregate = adaptor.getValueToStore();
1237  if (insertIntoInnermostDim) {
1238  // Scalar-into-1D-vector case, so we know we will have to create a
1239  // InsertElementOp. The question is into what destination.
1240  if (isNestedAggregate) {
1241  // Aggregate case: the destination for the InsertElementOp needs to be
1242  // extracted from the aggregate.
1243  if (!llvm::all_of(positionOf1DVectorWithinAggregate,
1244  llvm::IsaPred<Attribute>)) {
1245  // llvm.extractvalue does not support dynamic dimensions.
1246  return failure();
1247  }
1248  sourceAggregate = LLVM::ExtractValueOp::create(
1249  rewriter, loc, adaptor.getDest(),
1250  getAsIntegers(positionOf1DVectorWithinAggregate));
1251  } else {
1252  // No-aggregate case. The destination for the InsertElementOp is just
1253  // the insertOp's destination.
1254  sourceAggregate = adaptor.getDest();
1255  }
1256  // Insert the scalar into the 1D vector.
1257  sourceAggregate = LLVM::InsertElementOp::create(
1258  rewriter, loc, sourceAggregate.getType(), sourceAggregate,
1259  adaptor.getValueToStore(),
1260  getAsLLVMValue(rewriter, loc, positionOfScalarWithin1DVector));
1261  }
1262 
1263  Value result = sourceAggregate;
1264  if (isNestedAggregate) {
1265  result = LLVM::InsertValueOp::create(
1266  rewriter, loc, adaptor.getDest(), sourceAggregate,
1267  getAsIntegers(positionOf1DVectorWithinAggregate));
1268  }
1269 
1270  rewriter.replaceOp(insertOp, result);
1271  return success();
1272  }
1273 };
1274 
1275 /// Lower vector.scalable.insert ops to LLVM vector.insert
1276 struct VectorScalableInsertOpLowering
1277  : public ConvertOpToLLVMPattern<vector::ScalableInsertOp> {
1278  using ConvertOpToLLVMPattern<
1279  vector::ScalableInsertOp>::ConvertOpToLLVMPattern;
1280 
1281  LogicalResult
1282  matchAndRewrite(vector::ScalableInsertOp insOp, OpAdaptor adaptor,
1283  ConversionPatternRewriter &rewriter) const override {
1284  rewriter.replaceOpWithNewOp<LLVM::vector_insert>(
1285  insOp, adaptor.getDest(), adaptor.getValueToStore(), adaptor.getPos());
1286  return success();
1287  }
1288 };
1289 
1290 /// Lower vector.scalable.extract ops to LLVM vector.extract
1291 struct VectorScalableExtractOpLowering
1292  : public ConvertOpToLLVMPattern<vector::ScalableExtractOp> {
1293  using ConvertOpToLLVMPattern<
1294  vector::ScalableExtractOp>::ConvertOpToLLVMPattern;
1295 
1296  LogicalResult
1297  matchAndRewrite(vector::ScalableExtractOp extOp, OpAdaptor adaptor,
1298  ConversionPatternRewriter &rewriter) const override {
1299  rewriter.replaceOpWithNewOp<LLVM::vector_extract>(
1300  extOp, typeConverter->convertType(extOp.getResultVectorType()),
1301  adaptor.getSource(), adaptor.getPos());
1302  return success();
1303  }
1304 };
1305 
1306 /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
1307 ///
1308 /// Example:
1309 /// ```
1310 /// %d = vector.fma %a, %b, %c : vector<2x4xf32>
1311 /// ```
1312 /// is rewritten into:
1313 /// ```
1314 /// %r = vector.broadcast %f0 : f32 to vector<2x4xf32>
1315 /// %va = vector.extractvalue %a[0] : vector<2x4xf32>
1316 /// %vb = vector.extractvalue %b[0] : vector<2x4xf32>
1317 /// %vc = vector.extractvalue %c[0] : vector<2x4xf32>
1318 /// %vd = vector.fma %va, %vb, %vc : vector<4xf32>
1319 /// %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
1320 /// %va2 = vector.extractvalue %a2[1] : vector<2x4xf32>
1321 /// %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32>
1322 /// %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32>
1323 /// %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
1324 /// %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
1325 /// // %r3 holds the final value.
1326 /// ```
1327 class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
1328 public:
1330 
1331  void initialize() {
1332  // This pattern recursively unpacks one dimension at a time. The recursion
1333  // bounded as the rank is strictly decreasing.
1334  setHasBoundedRewriteRecursion();
1335  }
1336 
1337  LogicalResult matchAndRewrite(FMAOp op,
1338  PatternRewriter &rewriter) const override {
1339  auto vType = op.getVectorType();
1340  if (vType.getRank() < 2)
1341  return failure();
1342 
1343  auto loc = op.getLoc();
1344  auto elemType = vType.getElementType();
1345  Value zero = arith::ConstantOp::create(rewriter, loc, elemType,
1346  rewriter.getZeroAttr(elemType));
1347  Value desc = vector::BroadcastOp::create(rewriter, loc, vType, zero);
1348  for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
1349  Value extrLHS = ExtractOp::create(rewriter, loc, op.getLhs(), i);
1350  Value extrRHS = ExtractOp::create(rewriter, loc, op.getRhs(), i);
1351  Value extrACC = ExtractOp::create(rewriter, loc, op.getAcc(), i);
1352  Value fma = FMAOp::create(rewriter, loc, extrLHS, extrRHS, extrACC);
1353  desc = InsertOp::create(rewriter, loc, fma, desc, i);
1354  }
1355  rewriter.replaceOp(op, desc);
1356  return success();
1357  }
1358 };
1359 
1360 /// Returns the strides if the memory underlying `memRefType` has a contiguous
1361 /// static layout.
1362 static std::optional<SmallVector<int64_t, 4>>
1363 computeContiguousStrides(MemRefType memRefType) {
1364  int64_t offset;
1365  SmallVector<int64_t, 4> strides;
1366  if (failed(memRefType.getStridesAndOffset(strides, offset)))
1367  return std::nullopt;
1368  if (!strides.empty() && strides.back() != 1)
1369  return std::nullopt;
1370  // If no layout or identity layout, this is contiguous by definition.
1371  if (memRefType.getLayout().isIdentity())
1372  return strides;
1373 
1374  // Otherwise, we must determine contiguity form shapes. This can only ever
1375  // work in static cases because MemRefType is underspecified to represent
1376  // contiguous dynamic shapes in other ways than with just empty/identity
1377  // layout.
1378  auto sizes = memRefType.getShape();
1379  for (int index = 0, e = strides.size() - 1; index < e; ++index) {
1380  if (ShapedType::isDynamic(sizes[index + 1]) ||
1381  ShapedType::isDynamic(strides[index]) ||
1382  ShapedType::isDynamic(strides[index + 1]))
1383  return std::nullopt;
1384  if (strides[index] != strides[index + 1] * sizes[index + 1])
1385  return std::nullopt;
1386  }
1387  return strides;
1388 }
1389 
1390 class VectorTypeCastOpConversion
1391  : public ConvertOpToLLVMPattern<vector::TypeCastOp> {
1392 public:
1394 
1395  LogicalResult
1396  matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor,
1397  ConversionPatternRewriter &rewriter) const override {
1398  auto loc = castOp->getLoc();
1399  MemRefType sourceMemRefType =
1400  cast<MemRefType>(castOp.getOperand().getType());
1401  MemRefType targetMemRefType = castOp.getType();
1402 
1403  // Only static shape casts supported atm.
1404  if (!sourceMemRefType.hasStaticShape() ||
1405  !targetMemRefType.hasStaticShape())
1406  return failure();
1407 
1408  auto llvmSourceDescriptorTy =
1409  dyn_cast<LLVM::LLVMStructType>(adaptor.getOperands()[0].getType());
1410  if (!llvmSourceDescriptorTy)
1411  return failure();
1412  MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]);
1413 
1414  auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
1415  typeConverter->convertType(targetMemRefType));
1416  if (!llvmTargetDescriptorTy)
1417  return failure();
1418 
1419  // Only contiguous source buffers supported atm.
1420  auto sourceStrides = computeContiguousStrides(sourceMemRefType);
1421  if (!sourceStrides)
1422  return failure();
1423  auto targetStrides = computeContiguousStrides(targetMemRefType);
1424  if (!targetStrides)
1425  return failure();
1426  // Only support static strides for now, regardless of contiguity.
1427  if (llvm::any_of(*targetStrides, ShapedType::isDynamic))
1428  return failure();
1429 
1430  auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
1431 
1432  // Create descriptor.
1433  auto desc = MemRefDescriptor::poison(rewriter, loc, llvmTargetDescriptorTy);
1434  // Set allocated ptr.
1435  Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
1436  desc.setAllocatedPtr(rewriter, loc, allocated);
1437 
1438  // Set aligned ptr.
1439  Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
1440  desc.setAlignedPtr(rewriter, loc, ptr);
1441  // Fill offset 0.
1442  auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
1443  auto zero = LLVM::ConstantOp::create(rewriter, loc, int64Ty, attr);
1444  desc.setOffset(rewriter, loc, zero);
1445 
1446  // Fill size and stride descriptors in memref.
1447  for (const auto &indexedSize :
1448  llvm::enumerate(targetMemRefType.getShape())) {
1449  int64_t index = indexedSize.index();
1450  auto sizeAttr =
1451  rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
1452  auto size = LLVM::ConstantOp::create(rewriter, loc, int64Ty, sizeAttr);
1453  desc.setSize(rewriter, loc, index, size);
1454  auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
1455  (*targetStrides)[index]);
1456  auto stride =
1457  LLVM::ConstantOp::create(rewriter, loc, int64Ty, strideAttr);
1458  desc.setStride(rewriter, loc, index, stride);
1459  }
1460 
1461  rewriter.replaceOp(castOp, {desc});
1462  return success();
1463  }
1464 };
1465 
1466 /// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only).
1467 /// Non-scalable versions of this operation are handled in Vector Transforms.
1468 class VectorCreateMaskOpConversion
1469  : public OpConversionPattern<vector::CreateMaskOp> {
1470 public:
1471  explicit VectorCreateMaskOpConversion(MLIRContext *context,
1472  bool enableIndexOpt)
1473  : OpConversionPattern<vector::CreateMaskOp>(context),
1474  force32BitVectorIndices(enableIndexOpt) {}
1475 
1476  LogicalResult
1477  matchAndRewrite(vector::CreateMaskOp op, OpAdaptor adaptor,
1478  ConversionPatternRewriter &rewriter) const override {
1479  auto dstType = op.getType();
1480  if (dstType.getRank() != 1 || !cast<VectorType>(dstType).isScalable())
1481  return failure();
1482  IntegerType idxType =
1483  force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
1484  auto loc = op->getLoc();
1485  Value indices = LLVM::StepVectorOp::create(
1486  rewriter, loc,
1487  LLVM::getVectorType(idxType, dstType.getShape()[0],
1488  /*isScalable=*/true));
1489  auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType,
1490  adaptor.getOperands()[0]);
1491  Value bounds = BroadcastOp::create(rewriter, loc, indices.getType(), bound);
1492  Value comp = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
1493  indices, bounds);
1494  rewriter.replaceOp(op, comp);
1495  return success();
1496  }
1497 
1498 private:
1499  const bool force32BitVectorIndices;
1500 };
1501 
1502 class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
1503  SymbolTableCollection *symbolTables = nullptr;
1504 
1505 public:
1506  explicit VectorPrintOpConversion(
1507  const LLVMTypeConverter &typeConverter,
1508  SymbolTableCollection *symbolTables = nullptr)
1509  : ConvertOpToLLVMPattern<vector::PrintOp>(typeConverter),
1510  symbolTables(symbolTables) {}
1511 
1512  // Lowering implementation that relies on a small runtime support library,
1513  // which only needs to provide a few printing methods (single value for all
1514  // data types, opening/closing bracket, comma, newline). The lowering splits
1515  // the vector into elementary printing operations. The advantage of this
1516  // approach is that the library can remain unaware of all low-level
1517  // implementation details of vectors while still supporting output of any
1518  // shaped and dimensioned vector.
1519  //
1520  // Note: This lowering only handles scalars, n-D vectors are broken into
1521  // printing scalars in loops in VectorToSCF.
1522  //
1523  // TODO: rely solely on libc in future? something else?
1524  //
1525  LogicalResult
1526  matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor,
1527  ConversionPatternRewriter &rewriter) const override {
1528  auto parent = printOp->getParentOfType<ModuleOp>();
1529  if (!parent)
1530  return failure();
1531 
1532  auto loc = printOp->getLoc();
1533 
1534  if (auto value = adaptor.getSource()) {
1535  Type printType = printOp.getPrintType();
1536  if (isa<VectorType>(printType)) {
1537  // Vectors should be broken into elementary print ops in VectorToSCF.
1538  return failure();
1539  }
1540  if (failed(emitScalarPrint(rewriter, parent, loc, printType, value)))
1541  return failure();
1542  }
1543 
1544  auto punct = printOp.getPunctuation();
1545  if (auto stringLiteral = printOp.getStringLiteral()) {
1546  auto createResult =
1547  LLVM::createPrintStrCall(rewriter, loc, parent, "vector_print_str",
1548  *stringLiteral, *getTypeConverter(),
1549  /*addNewline=*/false);
1550  if (createResult.failed())
1551  return failure();
1552 
1553  } else if (punct != PrintPunctuation::NoPunctuation) {
1554  FailureOr<LLVM::LLVMFuncOp> op = [&]() {
1555  switch (punct) {
1556  case PrintPunctuation::Close:
1557  return LLVM::lookupOrCreatePrintCloseFn(rewriter, parent,
1558  symbolTables);
1559  case PrintPunctuation::Open:
1560  return LLVM::lookupOrCreatePrintOpenFn(rewriter, parent,
1561  symbolTables);
1562  case PrintPunctuation::Comma:
1563  return LLVM::lookupOrCreatePrintCommaFn(rewriter, parent,
1564  symbolTables);
1565  case PrintPunctuation::NewLine:
1566  return LLVM::lookupOrCreatePrintNewlineFn(rewriter, parent,
1567  symbolTables);
1568  default:
1569  llvm_unreachable("unexpected punctuation");
1570  }
1571  }();
1572  if (failed(op))
1573  return failure();
1574  emitCall(rewriter, printOp->getLoc(), op.value());
1575  }
1576 
1577  rewriter.eraseOp(printOp);
1578  return success();
1579  }
1580 
1581 private:
1582  enum class PrintConversion {
1583  // clang-format off
1584  None,
1585  ZeroExt64,
1586  SignExt64,
1587  Bitcast16
1588  // clang-format on
1589  };
1590 
  /// Emits a call to the appropriate runtime print function for a single
  /// scalar `value` of type `printType`, inserting an extension/bitcast first
  /// when the runtime function's parameter type differs from `printType`.
  /// Returns failure when the type cannot be converted, has no runtime print
  /// support (only f32/f64/f16/bf16/index and integers up to 64 bits are
  /// handled here), or the printer function lookup fails.
  LogicalResult emitScalarPrint(ConversionPatternRewriter &rewriter,
                                ModuleOp parent, Location loc, Type printType,
                                Value value) const {
    if (typeConverter->convertType(printType) == nullptr)
      return failure();

    // Make sure element type has runtime support.
    PrintConversion conversion = PrintConversion::None;
    FailureOr<Operation *> printer;
    if (printType.isF32()) {
      printer = LLVM::lookupOrCreatePrintF32Fn(rewriter, parent, symbolTables);
    } else if (printType.isF64()) {
      printer = LLVM::lookupOrCreatePrintF64Fn(rewriter, parent, symbolTables);
    } else if (printType.isF16()) {
      conversion = PrintConversion::Bitcast16; // bits!
      printer = LLVM::lookupOrCreatePrintF16Fn(rewriter, parent, symbolTables);
    } else if (printType.isBF16()) {
      conversion = PrintConversion::Bitcast16; // bits!
      printer = LLVM::lookupOrCreatePrintBF16Fn(rewriter, parent, symbolTables);
    } else if (printType.isIndex()) {
      // Index is printed through the unsigned 64-bit printer.
      printer = LLVM::lookupOrCreatePrintU64Fn(rewriter, parent, symbolTables);
    } else if (auto intTy = dyn_cast<IntegerType>(printType)) {
      // Integers need a zero or sign extension on the operand
      // (depending on the source type) as well as a signed or
      // unsigned print method. Up to 64-bit is supported.
      unsigned width = intTy.getWidth();
      if (intTy.isUnsigned()) {
        if (width <= 64) {
          if (width < 64)
            conversion = PrintConversion::ZeroExt64;
          printer =
              LLVM::lookupOrCreatePrintU64Fn(rewriter, parent, symbolTables);
        } else {
          // Wider than 64 bits: no runtime support.
          return failure();
        }
      } else {
        assert(intTy.isSignless() || intTy.isSigned());
        if (width <= 64) {
          // Note that we *always* zero extend booleans (1-bit integers),
          // so that true/false is printed as 1/0 rather than -1/0.
          if (width == 1)
            conversion = PrintConversion::ZeroExt64;
          else if (width < 64)
            conversion = PrintConversion::SignExt64;
          printer =
              LLVM::lookupOrCreatePrintI64Fn(rewriter, parent, symbolTables);
        } else {
          // Wider than 64 bits: no runtime support.
          return failure();
        }
      }
    } else {
      // Any other type (e.g. complex, opaque) is unsupported.
      return failure();
    }
    if (failed(printer))
      return failure();

    // Adapt the operand to the parameter type of the runtime function.
    switch (conversion) {
    case PrintConversion::ZeroExt64:
      value = arith::ExtUIOp::create(
          rewriter, loc, IntegerType::get(rewriter.getContext(), 64), value);
      break;
    case PrintConversion::SignExt64:
      value = arith::ExtSIOp::create(
          rewriter, loc, IntegerType::get(rewriter.getContext(), 64), value);
      break;
    case PrintConversion::Bitcast16:
      value = LLVM::BitcastOp::create(
          rewriter, loc, IntegerType::get(rewriter.getContext(), 16), value);
      break;
    case PrintConversion::None:
      break;
    }
    emitCall(rewriter, loc, printer.value(), value);
    return success();
  }
1666 
  // Helper to emit a call. The callee `ref` is referenced by symbol, takes
  // the given `params` (none by default), and returns no results.
  static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
                       Operation *ref, ValueRange params = ValueRange()) {
    LLVM::CallOp::create(rewriter, loc, TypeRange(), SymbolRefAttr::get(ref),
                         params);
  }
1673 };
1674 
1675 /// A broadcast of a scalar is lowered to an insertelement + a shufflevector
1676 /// operation. Only broadcasts to 0-d and 1-d vectors are lowered by this
1677 /// pattern, the higher rank cases are handled by another pattern.
1678 struct VectorBroadcastScalarToLowRankLowering
1679  : public ConvertOpToLLVMPattern<vector::BroadcastOp> {
1681 
1682  LogicalResult
1683  matchAndRewrite(vector::BroadcastOp broadcast, OpAdaptor adaptor,
1684  ConversionPatternRewriter &rewriter) const override {
1685  if (isa<VectorType>(broadcast.getSourceType()))
1686  return rewriter.notifyMatchFailure(
1687  broadcast, "broadcast from vector type not handled");
1688 
1689  VectorType resultType = broadcast.getType();
1690  if (resultType.getRank() > 1)
1691  return rewriter.notifyMatchFailure(broadcast,
1692  "broadcast to 2+-d handled elsewhere");
1693 
1694  // First insert it into a poison vector so we can shuffle it.
1695  auto vectorType = typeConverter->convertType(broadcast.getType());
1696  Value poison =
1697  LLVM::PoisonOp::create(rewriter, broadcast.getLoc(), vectorType);
1698  auto zero = LLVM::ConstantOp::create(
1699  rewriter, broadcast.getLoc(),
1700  typeConverter->convertType(rewriter.getIntegerType(32)),
1701  rewriter.getZeroAttr(rewriter.getIntegerType(32)));
1702 
1703  // For 0-d vector, we simply do `insertelement`.
1704  if (resultType.getRank() == 0) {
1705  rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
1706  broadcast, vectorType, poison, adaptor.getSource(), zero);
1707  return success();
1708  }
1709 
1710  // For 1-d vector, we additionally do a `vectorshuffle`.
1711  auto v =
1712  LLVM::InsertElementOp::create(rewriter, broadcast.getLoc(), vectorType,
1713  poison, adaptor.getSource(), zero);
1714 
1715  int64_t width = cast<VectorType>(broadcast.getType()).getDimSize(0);
1716  SmallVector<int32_t> zeroValues(width, 0);
1717 
1718  // Shuffle the value across the desired number of elements.
1719  rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(broadcast, v, poison,
1720  zeroValues);
1721  return success();
1722  }
1723 };
1724 
1725 /// The broadcast of a scalar is lowered to an insertelement + a shufflevector
1726 /// operation. Only broadcasts to 2+-d vector result types are lowered by this
1727 /// pattern, the 1-d case is handled by another pattern. Broadcasts from vectors
1728 /// are not converted to LLVM, only broadcasts from scalars are.
1729 struct VectorBroadcastScalarToNdLowering
1730  : public ConvertOpToLLVMPattern<BroadcastOp> {
1732 
1733  LogicalResult
1734  matchAndRewrite(BroadcastOp broadcast, OpAdaptor adaptor,
1735  ConversionPatternRewriter &rewriter) const override {
1736  if (isa<VectorType>(broadcast.getSourceType()))
1737  return rewriter.notifyMatchFailure(
1738  broadcast, "broadcast from vector type not handled");
1739 
1740  VectorType resultType = broadcast.getType();
1741  if (resultType.getRank() <= 1)
1742  return rewriter.notifyMatchFailure(
1743  broadcast, "broadcast to 1-d or 0-d handled elsewhere");
1744 
1745  // First insert it into an undef vector so we can shuffle it.
1746  auto loc = broadcast.getLoc();
1747  auto vectorTypeInfo =
1748  LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter());
1749  auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
1750  auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
1751  if (!llvmNDVectorTy || !llvm1DVectorTy)
1752  return failure();
1753 
1754  // Construct returned value.
1755  Value desc = LLVM::PoisonOp::create(rewriter, loc, llvmNDVectorTy);
1756 
1757  // Construct a 1-D vector with the broadcasted value that we insert in all
1758  // the places within the returned descriptor.
1759  Value vdesc = LLVM::PoisonOp::create(rewriter, loc, llvm1DVectorTy);
1760  auto zero = LLVM::ConstantOp::create(
1761  rewriter, loc, typeConverter->convertType(rewriter.getIntegerType(32)),
1762  rewriter.getZeroAttr(rewriter.getIntegerType(32)));
1763  Value v = LLVM::InsertElementOp::create(rewriter, loc, llvm1DVectorTy,
1764  vdesc, adaptor.getSource(), zero);
1765 
1766  // Shuffle the value across the desired number of elements.
1767  int64_t width = resultType.getDimSize(resultType.getRank() - 1);
1768  SmallVector<int32_t> zeroValues(width, 0);
1769  v = LLVM::ShuffleVectorOp::create(rewriter, loc, v, v, zeroValues);
1770 
1771  // Iterate of linear index, convert to coords space and insert broadcasted
1772  // 1-D vector in each position.
1773  nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
1774  desc = LLVM::InsertValueOp::create(rewriter, loc, desc, v, position);
1775  });
1776  rewriter.replaceOp(broadcast, desc);
1777  return success();
1778  }
1779 };
1780 
1781 /// Conversion pattern for a `vector.interleave`.
1782 /// This supports fixed-sized vectors and scalable vectors.
1783 struct VectorInterleaveOpLowering
1784  : public ConvertOpToLLVMPattern<vector::InterleaveOp> {
1786 
1787  LogicalResult
1788  matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
1789  ConversionPatternRewriter &rewriter) const override {
1790  VectorType resultType = interleaveOp.getResultVectorType();
1791  // n-D interleaves should have been lowered already.
1792  if (resultType.getRank() != 1)
1793  return rewriter.notifyMatchFailure(interleaveOp,
1794  "InterleaveOp not rank 1");
1795  // If the result is rank 1, then this directly maps to LLVM.
1796  if (resultType.isScalable()) {
1797  rewriter.replaceOpWithNewOp<LLVM::vector_interleave2>(
1798  interleaveOp, typeConverter->convertType(resultType),
1799  adaptor.getLhs(), adaptor.getRhs());
1800  return success();
1801  }
1802  // Lower fixed-size interleaves to a shufflevector. While the
1803  // vector.interleave2 intrinsic supports fixed and scalable vectors, the
1804  // langref still recommends fixed-vectors use shufflevector, see:
1805  // https://llvm.org/docs/LangRef.html#id876.
1806  int64_t resultVectorSize = resultType.getNumElements();
1807  SmallVector<int32_t> interleaveShuffleMask;
1808  interleaveShuffleMask.reserve(resultVectorSize);
1809  for (int i = 0, end = resultVectorSize / 2; i < end; ++i) {
1810  interleaveShuffleMask.push_back(i);
1811  interleaveShuffleMask.push_back((resultVectorSize / 2) + i);
1812  }
1813  rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(
1814  interleaveOp, adaptor.getLhs(), adaptor.getRhs(),
1815  interleaveShuffleMask);
1816  return success();
1817  }
1818 };
1819 
1820 /// Conversion pattern for a `vector.deinterleave`.
1821 /// This supports fixed-sized vectors and scalable vectors.
1822 struct VectorDeinterleaveOpLowering
1823  : public ConvertOpToLLVMPattern<vector::DeinterleaveOp> {
1825 
1826  LogicalResult
1827  matchAndRewrite(vector::DeinterleaveOp deinterleaveOp, OpAdaptor adaptor,
1828  ConversionPatternRewriter &rewriter) const override {
1829  VectorType resultType = deinterleaveOp.getResultVectorType();
1830  VectorType sourceType = deinterleaveOp.getSourceVectorType();
1831  auto loc = deinterleaveOp.getLoc();
1832 
1833  // Note: n-D deinterleave operations should be lowered to the 1-D before
1834  // converting to LLVM.
1835  if (resultType.getRank() != 1)
1836  return rewriter.notifyMatchFailure(deinterleaveOp,
1837  "DeinterleaveOp not rank 1");
1838 
1839  if (resultType.isScalable()) {
1840  auto llvmTypeConverter = this->getTypeConverter();
1841  auto deinterleaveResults = deinterleaveOp.getResultTypes();
1842  auto packedOpResults =
1843  llvmTypeConverter->packOperationResults(deinterleaveResults);
1844  auto intrinsic = LLVM::vector_deinterleave2::create(
1845  rewriter, loc, packedOpResults, adaptor.getSource());
1846 
1847  auto evenResult = LLVM::ExtractValueOp::create(
1848  rewriter, loc, intrinsic->getResult(0), 0);
1849  auto oddResult = LLVM::ExtractValueOp::create(rewriter, loc,
1850  intrinsic->getResult(0), 1);
1851 
1852  rewriter.replaceOp(deinterleaveOp, ValueRange{evenResult, oddResult});
1853  return success();
1854  }
1855  // Lower fixed-size deinterleave to two shufflevectors. While the
1856  // vector.deinterleave2 intrinsic supports fixed and scalable vectors, the
1857  // langref still recommends fixed-vectors use shufflevector, see:
1858  // https://llvm.org/docs/LangRef.html#id889.
1859  int64_t resultVectorSize = resultType.getNumElements();
1860  SmallVector<int32_t> evenShuffleMask;
1861  SmallVector<int32_t> oddShuffleMask;
1862 
1863  evenShuffleMask.reserve(resultVectorSize);
1864  oddShuffleMask.reserve(resultVectorSize);
1865 
1866  for (int i = 0; i < sourceType.getNumElements(); ++i) {
1867  if (i % 2 == 0)
1868  evenShuffleMask.push_back(i);
1869  else
1870  oddShuffleMask.push_back(i);
1871  }
1872 
1873  auto poison = LLVM::PoisonOp::create(rewriter, loc, sourceType);
1874  auto evenShuffle = LLVM::ShuffleVectorOp::create(
1875  rewriter, loc, adaptor.getSource(), poison, evenShuffleMask);
1876  auto oddShuffle = LLVM::ShuffleVectorOp::create(
1877  rewriter, loc, adaptor.getSource(), poison, oddShuffleMask);
1878 
1879  rewriter.replaceOp(deinterleaveOp, ValueRange{evenShuffle, oddShuffle});
1880  return success();
1881  }
1882 };
1883 
1884 /// Conversion pattern for a `vector.from_elements`.
1885 struct VectorFromElementsLowering
1886  : public ConvertOpToLLVMPattern<vector::FromElementsOp> {
1888 
1889  LogicalResult
1890  matchAndRewrite(vector::FromElementsOp fromElementsOp, OpAdaptor adaptor,
1891  ConversionPatternRewriter &rewriter) const override {
1892  Location loc = fromElementsOp.getLoc();
1893  VectorType vectorType = fromElementsOp.getType();
1894  // Only support 1-D vectors. Multi-dimensional vectors should have been
1895  // transformed to 1-D vectors by the vector-to-vector transformations before
1896  // this.
1897  if (vectorType.getRank() > 1)
1898  return rewriter.notifyMatchFailure(fromElementsOp,
1899  "rank > 1 vectors are not supported");
1900  Type llvmType = typeConverter->convertType(vectorType);
1901  Type llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
1902  Value result = LLVM::PoisonOp::create(rewriter, loc, llvmType);
1903  for (auto [idx, val] : llvm::enumerate(adaptor.getElements())) {
1904  auto constIdx =
1905  LLVM::ConstantOp::create(rewriter, loc, llvmIndexType, idx);
1906  result = LLVM::InsertElementOp::create(rewriter, loc, llvmType, result,
1907  val, constIdx);
1908  }
1909  rewriter.replaceOp(fromElementsOp, result);
1910  return success();
1911  }
1912 };
1913 
1914 /// Conversion pattern for a `vector.to_elements`.
1915 struct VectorToElementsLowering
1916  : public ConvertOpToLLVMPattern<vector::ToElementsOp> {
1918 
1919  LogicalResult
1920  matchAndRewrite(vector::ToElementsOp toElementsOp, OpAdaptor adaptor,
1921  ConversionPatternRewriter &rewriter) const override {
1922  Location loc = toElementsOp.getLoc();
1923  auto idxType = typeConverter->convertType(rewriter.getIndexType());
1924  Value source = adaptor.getSource();
1925 
1926  SmallVector<Value> results(toElementsOp->getNumResults());
1927  for (auto [idx, element] : llvm::enumerate(toElementsOp.getElements())) {
1928  // Create an extractelement operation only for results that are not dead.
1929  if (element.use_empty())
1930  continue;
1931 
1932  auto constIdx = LLVM::ConstantOp::create(
1933  rewriter, loc, idxType, rewriter.getIntegerAttr(idxType, idx));
1934  auto llvmType = typeConverter->convertType(element.getType());
1935 
1936  Value result = LLVM::ExtractElementOp::create(rewriter, loc, llvmType,
1937  source, constIdx);
1938  results[idx] = result;
1939  }
1940 
1941  rewriter.replaceOp(toElementsOp, results);
1942  return success();
1943  }
1944 };
1945 
1946 /// Conversion pattern for vector.step.
1947 struct VectorScalableStepOpLowering
1948  : public ConvertOpToLLVMPattern<vector::StepOp> {
1950 
1951  LogicalResult
1952  matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor,
1953  ConversionPatternRewriter &rewriter) const override {
1954  auto resultType = cast<VectorType>(stepOp.getType());
1955  if (!resultType.isScalable()) {
1956  return failure();
1957  }
1958  Type llvmType = typeConverter->convertType(stepOp.getType());
1959  rewriter.replaceOpWithNewOp<LLVM::StepVectorOp>(stepOp, llvmType);
1960  return success();
1961  }
1962 };
1963 
/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to:
/// ```
/// %flattened_a = vector.shape_cast %a
/// %flattened_b = vector.shape_cast %b
/// %flattened_d = vector.matrix_multiply %flattened_a, %flattened_b
/// %d = vector.shape_cast %%flattened_d
/// %e = add %c, %d
/// ```
/// `vector.matrix_multiply` later lowers to `llvm.matrix.multiply`.
//
/// This only kicks in when vectorContractLowering is set to Matmul and
/// the vector.contract op is a row-major matrix multiply.
class ContractionOpToMatmulOpLowering
    : public vector::MaskableOpRewritePattern<vector::ContractionOp> {
public:
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  // NOTE(review): `vectorContractLowering` is accepted but neither stored nor
  // consulted by this constructor -- presumably kept for API compatibility;
  // confirm before relying on it to gate this pattern.
  ContractionOpToMatmulOpLowering(
      vector::VectorContractLowering vectorContractLowering,
      MLIRContext *context, PatternBenefit benefit = 100)
      : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit) {}

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override;
};
1991 
1992 /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
1993 /// semantics to:
1994 /// ```
1995 /// %mta = maybe_transpose
1996 /// %mtb = maybe_transpose
1997 /// %flattened_a = vector.shape_cast %mta
1998 /// %flattened_b = vector.shape_cast %mtb
1999 /// %flattened_d = llvm.intr.matrix.multiply %flattened_a, %flattened_b
2000 /// %mtd = vector.shape_cast %flattened_d
2001 /// %d = maybe_untranspose %mtd
2002 /// %e = add %c, %d
2003 /// ```
2004 //
2005 /// This only kicks in when vectorContractLowering is set to `Matmul`.
2006 /// vector.transpose operations are inserted if the vector.contract op is not a
2007 /// row-major matrix multiply.
2008 ///
2009 /// Scalable vectors are not supported.
2010 FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
2011  vector::ContractionOp op, MaskingOpInterface maskOp,
2012  PatternRewriter &rew) const {
2013  // TODO: Support vector.mask.
2014  if (maskOp)
2015  return failure();
2016 
2017  auto iteratorTypes = op.getIteratorTypes().getValue();
2018  if (!isParallelIterator(iteratorTypes[0]) ||
2019  !isParallelIterator(iteratorTypes[1]) ||
2020  !isReductionIterator(iteratorTypes[2]))
2021  return failure();
2022 
2023  Type opResType = op.getType();
2024  VectorType vecType = dyn_cast<VectorType>(opResType);
2025  if (vecType && vecType.isScalable()) {
2026  // Note - this is sufficient to reject all cases with scalable vectors.
2027  return failure();
2028  }
2029 
2030  Type elementType = op.getLhsType().getElementType();
2031  if (!elementType.isIntOrFloat())
2032  return failure();
2033 
2034  Type dstElementType = vecType ? vecType.getElementType() : opResType;
2035  if (elementType != dstElementType)
2036  return failure();
2037 
2038  // Perform lhs + rhs transpositions to conform to matmul row-major semantics.
2039  // Bail out if the contraction cannot be put in this form.
2040  MLIRContext *ctx = op.getContext();
2041  Location loc = op.getLoc();
2042  AffineExpr m, n, k;
2043  bindDims(rew.getContext(), m, n, k);
2044  // LHS must be A(m, k) or A(k, m).
2045  Value lhs = op.getLhs();
2046  auto lhsMap = op.getIndexingMapsArray()[0];
2047  if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx))
2048  lhs = vector::TransposeOp::create(rew, loc, lhs, ArrayRef<int64_t>{1, 0});
2049  else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx))
2050  return failure();
2051 
2052  // RHS must be B(k, n) or B(n, k).
2053  Value rhs = op.getRhs();
2054  auto rhsMap = op.getIndexingMapsArray()[1];
2055  if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx))
2056  rhs = vector::TransposeOp::create(rew, loc, rhs, ArrayRef<int64_t>{1, 0});
2057  else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx))
2058  return failure();
2059 
2060  // At this point lhs and rhs are in row-major.
2061  VectorType lhsType = cast<VectorType>(lhs.getType());
2062  VectorType rhsType = cast<VectorType>(rhs.getType());
2063  int64_t lhsRows = lhsType.getDimSize(0);
2064  int64_t lhsColumns = lhsType.getDimSize(1);
2065  int64_t rhsColumns = rhsType.getDimSize(1);
2066 
2067  Type flattenedLHSType =
2068  VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
2069  lhs = vector::ShapeCastOp::create(rew, loc, flattenedLHSType, lhs);
2070 
2071  Type flattenedRHSType =
2072  VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
2073  rhs = vector::ShapeCastOp::create(rew, loc, flattenedRHSType, rhs);
2074 
2075  Value mul = LLVM::MatrixMultiplyOp::create(
2076  rew, loc,
2077  VectorType::get(lhsRows * rhsColumns,
2078  cast<VectorType>(lhs.getType()).getElementType()),
2079  lhs, rhs, lhsRows, lhsColumns, rhsColumns);
2080 
2081  mul = vector::ShapeCastOp::create(
2082  rew, loc,
2083  VectorType::get({lhsRows, rhsColumns},
2084  getElementTypeOrSelf(op.getAcc().getType())),
2085  mul);
2086 
2087  // ACC must be C(m, n) or C(n, m).
2088  auto accMap = op.getIndexingMapsArray()[2];
2089  if (accMap == AffineMap::get(3, 0, {n, m}, ctx))
2090  mul = vector::TransposeOp::create(rew, loc, mul, ArrayRef<int64_t>{1, 0});
2091  else if (accMap != AffineMap::get(3, 0, {m, n}, ctx))
2092  llvm_unreachable("invalid contraction semantics");
2093 
2094  Value res = isa<IntegerType>(elementType)
2095  ? static_cast<Value>(
2096  arith::AddIOp::create(rew, loc, op.getAcc(), mul))
2097  : static_cast<Value>(
2098  arith::AddFOp::create(rew, loc, op.getAcc(), mul));
2099 
2100  return res;
2101 }
2102 
2103 /// Lowers vector.transpose to llvm.intr.matrix.transpose
2104 class TransposeOpToMatrixTransposeOpLowering
2105  : public OpRewritePattern<vector::TransposeOp> {
2106 public:
2108 
2109  LogicalResult matchAndRewrite(vector::TransposeOp op,
2110  PatternRewriter &rewriter) const override {
2111  auto loc = op.getLoc();
2112 
2113  Value input = op.getVector();
2114  VectorType inputType = op.getSourceVectorType();
2115  VectorType resType = op.getResultVectorType();
2116 
2117  if (inputType.isScalable())
2118  return rewriter.notifyMatchFailure(
2119  op, "This lowering does not support scalable vectors");
2120 
2121  // Set up convenience transposition table.
2122  ArrayRef<int64_t> transp = op.getPermutation();
2123 
2124  if (resType.getRank() != 2 || transp[0] != 1 || transp[1] != 0) {
2125  return failure();
2126  }
2127 
2128  Type flattenedType =
2129  VectorType::get(resType.getNumElements(), resType.getElementType());
2130  auto matrix =
2131  vector::ShapeCastOp::create(rewriter, loc, flattenedType, input);
2132  auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
2133  auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
2134  Value trans = LLVM::MatrixTransposeOp::create(rewriter, loc, flattenedType,
2135  matrix, rows, columns);
2136  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
2137  return success();
2138  }
2139 };
2140 
2141 /// Convert `vector.splat` to `vector.broadcast`. There is a path to LLVM from
2142 /// `vector.broadcast` through other patterns.
2143 struct VectorSplatToBroadcast : public ConvertOpToLLVMPattern<vector::SplatOp> {
2145  LogicalResult
2146  matchAndRewrite(vector::SplatOp splat, OpAdaptor adaptor,
2147  ConversionPatternRewriter &rewriter) const override {
2148  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splat, splat.getType(),
2149  adaptor.getInput());
2150  return success();
2151  }
2152 };
2153 
2154 } // namespace
2155 
2158  patterns.add<VectorFMAOpNDRewritePattern>(patterns.getContext());
2159 }
2160 
2163  patterns.add<ContractionOpToMatmulOpLowering>(patterns.getContext(), benefit);
2164 }
2165 
2168  patterns.add<TransposeOpToMatrixTransposeOpLowering>(patterns.getContext(),
2169  benefit);
2170 }
2171 
2172 /// Populate the given list with patterns that convert from Vector to LLVM.
2174  const LLVMTypeConverter &converter, RewritePatternSet &patterns,
2175  bool reassociateFPReductions, bool force32BitVectorIndices,
2176  bool useVectorAlignment) {
2177  // This function populates only ConversionPatterns, not RewritePatterns.
2178  MLIRContext *ctx = converter.getDialect()->getContext();
2179  patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
2180  patterns.add<VectorCreateMaskOpConversion>(ctx, force32BitVectorIndices);
2181  patterns.add<VectorLoadStoreConversion<vector::LoadOp>,
2182  VectorLoadStoreConversion<vector::MaskedLoadOp>,
2183  VectorLoadStoreConversion<vector::StoreOp>,
2184  VectorLoadStoreConversion<vector::MaskedStoreOp>,
2185  VectorGatherOpConversion, VectorScatterOpConversion>(
2186  converter, useVectorAlignment);
2187  patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
2188  VectorExtractOpConversion, VectorFMAOp1DConversion,
2189  VectorInsertOpConversion, VectorPrintOpConversion,
2190  VectorTypeCastOpConversion, VectorScaleOpConversion,
2191  VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
2192  VectorSplatToBroadcast, VectorBroadcastScalarToLowRankLowering,
2193  VectorBroadcastScalarToNdLowering,
2194  VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
2195  MaskedReductionOpConversion, VectorInterleaveOpLowering,
2196  VectorDeinterleaveOpLowering, VectorFromElementsLowering,
2197  VectorToElementsLowering, VectorScalableStepOpLowering>(
2198  converter);
2199 }
2200 
2201 namespace {
2202 struct VectorToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
2204  void loadDependentDialects(MLIRContext *context) const final {
2205  context->loadDialect<LLVM::LLVMDialect>();
2206  }
2207 
2208  /// Hook for derived dialect interface to provide conversion patterns
2209  /// and mark dialect legal for the conversion target.
2210  void populateConvertToLLVMConversionPatterns(
2211  ConversionTarget &target, LLVMTypeConverter &typeConverter,
2212  RewritePatternSet &patterns) const final {
2214  }
2215 };
2216 } // namespace
2217 
2219  DialectRegistry &registry) {
2220  registry.addExtension(+[](MLIRContext *ctx, vector::VectorDialect *dialect) {
2221  dialect->addInterfaces<VectorToLLVMDialectInterface>();
2222  });
2223 }
static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc, const LLVMTypeConverter &typeConverter, MemRefType memRefType, Value llvmMemref, Value base, Value index, VectorType vectorType)
LogicalResult getVectorToLLVMAlignment(const LLVMTypeConverter &typeConverter, VectorType vectorType, MemRefType memrefType, unsigned &align, bool useVectorAlignment)
LogicalResult getVectorAlignment(const LLVMTypeConverter &typeConverter, VectorType vectorType, unsigned &align)
LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter, MemRefType memrefType, unsigned &align)
static Value extractOne(ConversionPatternRewriter &rewriter, const LLVMTypeConverter &typeConverter, Location loc, Value val, Type llvmType, int64_t rank, int64_t pos)
static Value insertOne(ConversionPatternRewriter &rewriter, const LLVMTypeConverter &typeConverter, Location loc, Value val1, Value val2, Type llvmType, int64_t rank, int64_t pos)
static Value getAsLLVMValue(OpBuilder &builder, Location loc, OpFoldResult foldResult)
Convert foldResult into a Value.
static LogicalResult isMemRefTypeSupported(MemRefType memRefType, const LLVMTypeConverter &converter)
union mlir::linalg::@1242::ArityGroupAndKind::Kind kind
@ None
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
#define MINUI(lhs, rhs)
static void printOp(llvm::raw_ostream &os, Operation *op, OpPrintingFlags &flags)
Definition: Unit.cpp:18
int64_t rows
Base type for affine expression.
Definition: AffineExpr.h:68
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:195
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:223
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:249
IntegerType getI64Type()
Definition: Builders.cpp:64
IntegerType getI32Type()
Definition: Builders.cpp:62
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:66
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:319
MLIRContext * getContext() const
Definition: Builders.h:55
IndexType getIndexType()
Definition: Builders.cpp:50
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:209
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition: Pattern.h:215
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
LLVM::LLVMDialect * getDialect() const
Returns the LLVM dialect.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type type or failure if...
const llvm::DataLayout & getDataLayout() const
Returns the data layout to use during and after conversion.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Definition: MemRefBuilder.h:33
LLVM::LLVMPointerType getElementPtrType()
Returns the (LLVM) pointer type this descriptor contains.
static MemRefDescriptor poison(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating a poison value of the descriptor type.
Generic implementation of one-to-one conversion from "SourceOp" to "TargetOp" where the latter belong...
Definition: Pattern.h:300
This class helps build Operations.
Definition: Builders.h:205
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
This class represents a single result from folding an operation.
Definition: OpDefinition.h:272
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:783
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:716
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
Definition: Types.cpp:112
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:116
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:24
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder, function_ref< void(ArrayRef< int64_t >)> fun)
NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType, const LLVMTypeConverter &converter)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintBF16Fn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintOpenFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
Value getStridedElementPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none)
Performs the index computation to get to the element at indices of the memory pointed to by memRefDesc.
Definition: Pattern.cpp:478
Type getVectorType(Type elementType, unsigned numElements, bool isScalable=false)
Creates an LLVM dialect-compatible vector type with the given element type and length.
Definition: LLVMTypes.cpp:840
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintCommaFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintI64Fn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
Helper functions to look up or create the declaration for commonly used external C function calls.
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintNewlineFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintCloseFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintU64Fn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintF32Fn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
LogicalResult createPrintStrCall(OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName, StringRef string, const LLVMTypeConverter &typeConverter, bool addNewline=true, std::optional< StringRef > runtimeFunctionName={}, SymbolTableCollection *symbolTables=nullptr)
Generate IR that prints the given string to stdout.
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintF16Fn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
LLVM::FastmathFlags convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF)
Maps arithmetic fastmath enum values to LLVM enum values.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
bool isReductionIterator(Attribute attr)
Returns true if attr has "reduction" iterator type semantics.
Definition: VectorOps.h:154
void populateVectorContractToMatrixMultiply(RewritePatternSet &patterns, PatternBenefit benefit=100)
Populate the pattern set with the following patterns:
void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns)
Populates a pattern that rank-reduces n-D FMAs into (n-1)-D FMAs where n > 1.
bool isParallelIterator(Attribute attr)
Returns true if attr has "parallel" iterator type semantics.
Definition: VectorOps.h:149
void registerConvertVectorToLLVMInterface(DialectRegistry &registry)
SmallVector< int64_t > getAsIntegers(ArrayRef< Value > values)
Returns the integer numbers in values.
Definition: VectorOps.cpp:345
void populateVectorTransposeToFlatTranspose(RewritePatternSet &patterns, PatternBenefit benefit=100)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
Definition: AffineExpr.h:311
Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, Type targetType, Value value)
Create a cast from an index-like value (index or integer) to another index-like value.
Definition: Utils.cpp:119
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
void populateVectorToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, bool reassociateFPReductions=false, bool force32BitVectorIndices=false, bool useVectorAlignment=false)
Collect a set of patterns to convert from the Vector dialect to LLVM.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which ShapedType::isDynamic is true, will be replaced by dynamicValues.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
Definition: PatternMatch.h:314
A pattern for ops that implement MaskableOpInterface and that might be masked (i.e. inside a vector.mask operation).
Definition: VectorUtils.h:163