MLIR  22.0.0git
ConvertVectorToLLVM.cpp
Go to the documentation of this file.
1 //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
27 #include "mlir/IR/BuiltinTypes.h"
28 #include "mlir/IR/TypeUtilities.h"
31 #include "llvm/ADT/APFloat.h"
32 #include "llvm/IR/LLVMContext.h"
33 #include "llvm/Support/Casting.h"
34 
35 #include <optional>
36 
37 using namespace mlir;
38 using namespace mlir::vector;
39 
40 // Helper that picks the proper sequence for inserting.
42  const LLVMTypeConverter &typeConverter, Location loc,
43  Value val1, Value val2, Type llvmType, int64_t rank,
44  int64_t pos) {
45  assert(rank > 0 && "0-D vector corner case should have been handled already");
46  if (rank == 1) {
47  auto idxType = rewriter.getIndexType();
48  auto constant = LLVM::ConstantOp::create(
49  rewriter, loc, typeConverter.convertType(idxType),
50  rewriter.getIntegerAttr(idxType, pos));
51  return LLVM::InsertElementOp::create(rewriter, loc, llvmType, val1, val2,
52  constant);
53  }
54  return LLVM::InsertValueOp::create(rewriter, loc, val1, val2, pos);
55 }
56 
57 // Helper that picks the proper sequence for extracting.
59  const LLVMTypeConverter &typeConverter, Location loc,
60  Value val, Type llvmType, int64_t rank, int64_t pos) {
61  if (rank <= 1) {
62  auto idxType = rewriter.getIndexType();
63  auto constant = LLVM::ConstantOp::create(
64  rewriter, loc, typeConverter.convertType(idxType),
65  rewriter.getIntegerAttr(idxType, pos));
66  return LLVM::ExtractElementOp::create(rewriter, loc, llvmType, val,
67  constant);
68  }
69  return LLVM::ExtractValueOp::create(rewriter, loc, val, pos);
70 }
71 
72 // Helper that returns data layout alignment of a vector.
73 LogicalResult getVectorAlignment(const LLVMTypeConverter &typeConverter,
74  VectorType vectorType, unsigned &align) {
75  Type convertedVectorTy = typeConverter.convertType(vectorType);
76  if (!convertedVectorTy)
77  return failure();
78 
79  llvm::LLVMContext llvmContext;
80  align = LLVM::TypeToLLVMIRTranslator(llvmContext)
81  .getPreferredAlignment(convertedVectorTy,
82  typeConverter.getDataLayout());
83 
84  return success();
85 }
86 
87 // Helper that returns data layout alignment of a memref.
88 LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter,
89  MemRefType memrefType, unsigned &align) {
90  Type elementTy = typeConverter.convertType(memrefType.getElementType());
91  if (!elementTy)
92  return failure();
93 
94  // TODO: this should use the MLIR data layout when it becomes available and
95  // stop depending on translation.
96  llvm::LLVMContext llvmContext;
97  align = LLVM::TypeToLLVMIRTranslator(llvmContext)
98  .getPreferredAlignment(elementTy, typeConverter.getDataLayout());
99  return success();
100 }
101 
102 // Helper to resolve the alignment for vector load/store, gather and scatter
103 // ops. If useVectorAlignment is true, get the preferred alignment for the
104 // vector type in the operation. This option is used for hardware backends with
105 // vectorization. Otherwise, use the preferred alignment of the element type of
106 // the memref. Note that if you choose to use vector alignment, the shape of the
107 // vector type must be resolved before the ConvertVectorToLLVM pass is run.
108 LogicalResult getVectorToLLVMAlignment(const LLVMTypeConverter &typeConverter,
109  VectorType vectorType,
110  MemRefType memrefType, unsigned &align,
111  bool useVectorAlignment) {
112  if (useVectorAlignment) {
113  if (failed(getVectorAlignment(typeConverter, vectorType, align))) {
114  return failure();
115  }
116  } else {
117  if (failed(getMemRefAlignment(typeConverter, memrefType, align))) {
118  return failure();
119  }
120  }
121  return success();
122 }
123 
124 // Check if the last stride is non-unit and has a valid memory space.
125 static LogicalResult isMemRefTypeSupported(MemRefType memRefType,
126  const LLVMTypeConverter &converter) {
127  if (!memRefType.isLastDimUnitStride())
128  return failure();
129  if (failed(converter.getMemRefAddressSpace(memRefType)))
130  return failure();
131  return success();
132 }
133 
134 // Add an index vector component to a base pointer.
136  const LLVMTypeConverter &typeConverter,
137  MemRefType memRefType, Value llvmMemref, Value base,
138  Value index, VectorType vectorType) {
139  assert(succeeded(isMemRefTypeSupported(memRefType, typeConverter)) &&
140  "unsupported memref type");
141  assert(vectorType.getRank() == 1 && "expected a 1-d vector type");
142  auto pType = MemRefDescriptor(llvmMemref).getElementPtrType();
143  auto ptrsType =
144  LLVM::getVectorType(pType, vectorType.getDimSize(0),
145  /*isScalable=*/vectorType.getScalableDims()[0]);
146  return LLVM::GEPOp::create(
147  rewriter, loc, ptrsType,
148  typeConverter.convertType(memRefType.getElementType()), base, index);
149 }
150 
151 /// Convert `foldResult` into a Value. Integer attribute is converted to
152 /// an LLVM constant op.
153 static Value getAsLLVMValue(OpBuilder &builder, Location loc,
154  OpFoldResult foldResult) {
155  if (auto attr = dyn_cast<Attribute>(foldResult)) {
156  auto intAttr = cast<IntegerAttr>(attr);
157  return LLVM::ConstantOp::create(builder, loc, intAttr).getResult();
158  }
159 
160  return cast<Value>(foldResult);
161 }
162 
163 namespace {
164 
165 /// Trivial Vector to LLVM conversions
166 using VectorScaleOpConversion =
168 
/// Conversion pattern for a vector.bitcast.
class VectorBitCastOpConversion
    : public ConvertOpToLLVMPattern<vector::BitCastOp> {
public:

  /// Rewrites vector.bitcast into llvm.bitcast on the converted result type.
  /// Only rank-0 and rank-1 result vectors are handled; higher ranks are
  /// rejected (they cannot be expressed as a single LLVM vector).
  LogicalResult
  matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only 0-D and 1-D vectors can be lowered to LLVM.
    VectorType resultTy = bitCastOp.getResultVectorType();
    if (resultTy.getRank() > 1)
      return failure();
    Type newResultTy = typeConverter->convertType(resultTy);
    rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy,
                                                 adaptor.getOperands()[0]);
    return success();
  }
};
188 
189 /// Overloaded utility that replaces a vector.load, vector.store,
190 /// vector.maskedload and vector.maskedstore with their respective LLVM
191 /// couterparts.
192 static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
193  vector::LoadOpAdaptor adaptor,
194  VectorType vectorTy, Value ptr, unsigned align,
195  ConversionPatternRewriter &rewriter) {
196  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, vectorTy, ptr, align,
197  /*volatile_=*/false,
198  loadOp.getNontemporal());
199 }
200 
201 static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
202  vector::MaskedLoadOpAdaptor adaptor,
203  VectorType vectorTy, Value ptr, unsigned align,
204  ConversionPatternRewriter &rewriter) {
205  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
206  loadOp, vectorTy, ptr, adaptor.getMask(), adaptor.getPassThru(), align);
207 }
208 
209 static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
210  vector::StoreOpAdaptor adaptor,
211  VectorType vectorTy, Value ptr, unsigned align,
212  ConversionPatternRewriter &rewriter) {
213  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValueToStore(),
214  ptr, align, /*volatile_=*/false,
215  storeOp.getNontemporal());
216 }
217 
218 static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
219  vector::MaskedStoreOpAdaptor adaptor,
220  VectorType vectorTy, Value ptr, unsigned align,
221  ConversionPatternRewriter &rewriter) {
222  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
223  storeOp, adaptor.getValueToStore(), ptr, adaptor.getMask(), align);
224 }
225 
/// Conversion pattern for a vector.load, vector.store, vector.maskedload, and
/// vector.maskedstore.
template <class LoadOrStoreOp>
class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
public:
  /// `useVectorAlign` selects how the fallback alignment is computed (vector
  /// type vs. memref element type) when the op carries no explicit alignment.
  explicit VectorLoadStoreConversion(const LLVMTypeConverter &typeConv,
                                     bool useVectorAlign)
      : ConvertOpToLLVMPattern<LoadOrStoreOp>(typeConv),
        useVectorAlignment(useVectorAlign) {}

  LogicalResult
  matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
                  typename LoadOrStoreOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only 1-D vectors can be lowered to LLVM.
    VectorType vectorTy = loadOrStoreOp.getVectorType();
    if (vectorTy.getRank() > 1)
      return failure();

    auto loc = loadOrStoreOp->getLoc();
    MemRefType memRefTy = loadOrStoreOp.getMemRefType();

    // Resolve alignment.
    // Explicit alignment takes priority over use-vector-alignment.
    unsigned align = loadOrStoreOp.getAlignment().value_or(0);
    if (!align &&
        failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vectorTy,
                                        memRefTy, align, useVectorAlignment)))
      return rewriter.notifyMatchFailure(loadOrStoreOp,
                                         "could not resolve alignment");

    // Resolve address.
    auto vtype = cast<VectorType>(
        this->typeConverter->convertType(loadOrStoreOp.getVectorType()));
    Value dataPtr = this->getStridedElementPtr(
        rewriter, loc, memRefTy, adaptor.getBase(), adaptor.getIndices());
    // Dispatch to the matching replaceLoadOrStoreOp overload for this op kind.
    replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, dataPtr, align,
                         rewriter);
    return success();
  }

private:
  // If true, use the preferred alignment of the vector type.
  // If false, use the preferred alignment of the element type
  // of the memref. This flag is intended for use with hardware
  // backends that require alignment of vector operations.
  const bool useVectorAlignment;
};
275 
/// Conversion pattern for a vector.gather.
class VectorGatherOpConversion
    : public ConvertOpToLLVMPattern<vector::GatherOp> {
public:
  /// `useVectorAlign` selects how the fallback alignment is computed (vector
  /// type vs. memref element type); an explicit alignment attribute on the op
  /// always takes priority.
  explicit VectorGatherOpConversion(const LLVMTypeConverter &typeConv,
                                    bool useVectorAlign)
      : ConvertOpToLLVMPattern<vector::GatherOp>(typeConv),
        useVectorAlignment(useVectorAlign) {}

  LogicalResult
  matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = gather->getLoc();
    MemRefType memRefType = dyn_cast<MemRefType>(gather.getBaseType());
    assert(memRefType && "The base should be bufferized");

    // Requires unit innermost stride and a convertible memory space.
    if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
      return rewriter.notifyMatchFailure(gather, "memref type not supported");

    VectorType vType = gather.getVectorType();
    if (vType.getRank() > 1) {
      return rewriter.notifyMatchFailure(
          gather, "only 1-D vectors can be lowered to LLVM");
    }

    // Resolve alignment.
    // Explicit alignment takes priority over use-vector-alignment.
    unsigned align = gather.getAlignment().value_or(0);
    if (!align &&
        failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
                                        memRefType, align, useVectorAlignment)))
      return rewriter.notifyMatchFailure(gather, "could not resolve alignment");

    // Resolve address: first the scalar base pointer (memref descriptor plus
    // offsets), then a vector of per-lane pointers from the index vector.
    Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
                                     adaptor.getBase(), adaptor.getOffsets());
    Value base = adaptor.getBase();
    Value ptrs =
        getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
                       base, ptr, adaptor.getIndices(), vType);

    // Replace with the gather intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
        gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
        adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
    return success();
  }

private:
  // If true, use the preferred alignment of the vector type.
  // If false, use the preferred alignment of the element type
  // of the memref. This flag is intended for use with hardware
  // backends that require alignment of vector operations.
  const bool useVectorAlignment;
};
332 
/// Conversion pattern for a vector.scatter.
class VectorScatterOpConversion
    : public ConvertOpToLLVMPattern<vector::ScatterOp> {
public:
  /// `useVectorAlign` selects how the fallback alignment is computed (vector
  /// type vs. memref element type); an explicit alignment attribute on the op
  /// always takes priority.
  explicit VectorScatterOpConversion(const LLVMTypeConverter &typeConv,
                                     bool useVectorAlign)
      : ConvertOpToLLVMPattern<vector::ScatterOp>(typeConv),
        useVectorAlignment(useVectorAlign) {}

  LogicalResult
  matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = scatter->getLoc();
    MemRefType memRefType = scatter.getMemRefType();

    // Requires unit innermost stride and a convertible memory space.
    if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
      return rewriter.notifyMatchFailure(scatter, "memref type not supported");

    VectorType vType = scatter.getVectorType();
    if (vType.getRank() > 1) {
      return rewriter.notifyMatchFailure(
          scatter, "only 1-D vectors can be lowered to LLVM");
    }

    // Resolve alignment.
    // Explicit alignment takes priority over use-vector-alignment.
    unsigned align = scatter.getAlignment().value_or(0);
    if (!align &&
        failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
                                        memRefType, align, useVectorAlignment)))
      return rewriter.notifyMatchFailure(scatter,
                                         "could not resolve alignment");

    // Resolve address: scalar base pointer, then a vector of per-lane
    // pointers formed from the index vector.
    Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
                                     adaptor.getBase(), adaptor.getOffsets());
    Value ptrs =
        getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
                       adaptor.getBase(), ptr, adaptor.getIndices(), vType);

    // Replace with the scatter intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
        scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }

private:
  // If true, use the preferred alignment of the vector type.
  // If false, use the preferred alignment of the element type
  // of the memref. This flag is intended for use with hardware
  // backends that require alignment of vector operations.
  const bool useVectorAlignment;
};
389 
/// Conversion pattern for a vector.expandload.
class VectorExpandLoadOpConversion
    : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
public:

  /// Lowers vector.expandload to the llvm.masked.expandload intrinsic.
  LogicalResult
  matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = expand->getLoc();
    MemRefType memRefType = expand.getMemRefType();

    // Resolve address.
    auto vtype = typeConverter->convertType(expand.getVectorType());
    Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
                                     adaptor.getBase(), adaptor.getIndices());

    // From:
    // https://llvm.org/docs/LangRef.html#llvm-masked-expandload-intrinsics
    // The pointer alignment defaults to 1.
    uint64_t alignment = expand.getAlignment().value_or(1);

    rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
        expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru(),
        alignment);
    return success();
  }
};
418 
/// Conversion pattern for a vector.compressstore.
class VectorCompressStoreOpConversion
    : public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
public:

  /// Lowers vector.compressstore to the llvm.masked.compressstore intrinsic.
  LogicalResult
  matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = compress->getLoc();
    MemRefType memRefType = compress.getMemRefType();

    // Resolve address.
    Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
                                     adaptor.getBase(), adaptor.getIndices());

    // From:
    // https://llvm.org/docs/LangRef.html#llvm-masked-compressstore-intrinsics
    // The pointer alignment defaults to 1.
    uint64_t alignment = compress.getAlignment().value_or(1);

    rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
        compress, adaptor.getValueToStore(), ptr, adaptor.getMask(), alignment);
    return success();
  }
};
445 
/// Reduction neutral classes for overloading.
/// Empty tag types that select, via overload resolution on
/// createReductionNeutralValue, which neutral (start) element a reduction
/// uses when no explicit accumulator is provided. The names describe the
/// constant each tag produces.
class ReductionNeutralZero {};
class ReductionNeutralIntOne {};
class ReductionNeutralFPOne {};
class ReductionNeutralAllOnes {};
class ReductionNeutralSIntMin {};
class ReductionNeutralUIntMin {};
class ReductionNeutralSIntMax {};
class ReductionNeutralUIntMax {};
class ReductionNeutralFPMin {};
class ReductionNeutralFPMax {};
457 
458 /// Create the reduction neutral zero value.
459 static Value createReductionNeutralValue(ReductionNeutralZero neutral,
460  ConversionPatternRewriter &rewriter,
461  Location loc, Type llvmType) {
462  return LLVM::ConstantOp::create(rewriter, loc, llvmType,
463  rewriter.getZeroAttr(llvmType));
464 }
465 
466 /// Create the reduction neutral integer one value.
467 static Value createReductionNeutralValue(ReductionNeutralIntOne neutral,
468  ConversionPatternRewriter &rewriter,
469  Location loc, Type llvmType) {
470  return LLVM::ConstantOp::create(rewriter, loc, llvmType,
471  rewriter.getIntegerAttr(llvmType, 1));
472 }
473 
474 /// Create the reduction neutral fp one value.
475 static Value createReductionNeutralValue(ReductionNeutralFPOne neutral,
476  ConversionPatternRewriter &rewriter,
477  Location loc, Type llvmType) {
478  return LLVM::ConstantOp::create(rewriter, loc, llvmType,
479  rewriter.getFloatAttr(llvmType, 1.0));
480 }
481 
482 /// Create the reduction neutral all-ones value.
483 static Value createReductionNeutralValue(ReductionNeutralAllOnes neutral,
484  ConversionPatternRewriter &rewriter,
485  Location loc, Type llvmType) {
486  return LLVM::ConstantOp::create(
487  rewriter, loc, llvmType,
488  rewriter.getIntegerAttr(
489  llvmType, llvm::APInt::getAllOnes(llvmType.getIntOrFloatBitWidth())));
490 }
491 
492 /// Create the reduction neutral signed int minimum value.
493 static Value createReductionNeutralValue(ReductionNeutralSIntMin neutral,
494  ConversionPatternRewriter &rewriter,
495  Location loc, Type llvmType) {
496  return LLVM::ConstantOp::create(
497  rewriter, loc, llvmType,
498  rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMinValue(
499  llvmType.getIntOrFloatBitWidth())));
500 }
501 
502 /// Create the reduction neutral unsigned int minimum value.
503 static Value createReductionNeutralValue(ReductionNeutralUIntMin neutral,
504  ConversionPatternRewriter &rewriter,
505  Location loc, Type llvmType) {
506  return LLVM::ConstantOp::create(
507  rewriter, loc, llvmType,
508  rewriter.getIntegerAttr(llvmType, llvm::APInt::getMinValue(
509  llvmType.getIntOrFloatBitWidth())));
510 }
511 
512 /// Create the reduction neutral signed int maximum value.
513 static Value createReductionNeutralValue(ReductionNeutralSIntMax neutral,
514  ConversionPatternRewriter &rewriter,
515  Location loc, Type llvmType) {
516  return LLVM::ConstantOp::create(
517  rewriter, loc, llvmType,
518  rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMaxValue(
519  llvmType.getIntOrFloatBitWidth())));
520 }
521 
522 /// Create the reduction neutral unsigned int maximum value.
523 static Value createReductionNeutralValue(ReductionNeutralUIntMax neutral,
524  ConversionPatternRewriter &rewriter,
525  Location loc, Type llvmType) {
526  return LLVM::ConstantOp::create(
527  rewriter, loc, llvmType,
528  rewriter.getIntegerAttr(llvmType, llvm::APInt::getMaxValue(
529  llvmType.getIntOrFloatBitWidth())));
530 }
531 
532 /// Create the reduction neutral fp minimum value.
533 static Value createReductionNeutralValue(ReductionNeutralFPMin neutral,
534  ConversionPatternRewriter &rewriter,
535  Location loc, Type llvmType) {
536  auto floatType = cast<FloatType>(llvmType);
537  return LLVM::ConstantOp::create(
538  rewriter, loc, llvmType,
539  rewriter.getFloatAttr(
540  llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
541  /*Negative=*/false)));
542 }
543 
544 /// Create the reduction neutral fp maximum value.
545 static Value createReductionNeutralValue(ReductionNeutralFPMax neutral,
546  ConversionPatternRewriter &rewriter,
547  Location loc, Type llvmType) {
548  auto floatType = cast<FloatType>(llvmType);
549  return LLVM::ConstantOp::create(
550  rewriter, loc, llvmType,
551  rewriter.getFloatAttr(
552  llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
553  /*Negative=*/true)));
554 }
555 
556 /// Returns `accumulator` if it has a valid value. Otherwise, creates and
557 /// returns a new accumulator value using `ReductionNeutral`.
558 template <class ReductionNeutral>
559 static Value getOrCreateAccumulator(ConversionPatternRewriter &rewriter,
560  Location loc, Type llvmType,
561  Value accumulator) {
562  if (accumulator)
563  return accumulator;
564 
565  return createReductionNeutralValue(ReductionNeutral(), rewriter, loc,
566  llvmType);
567 }
568 
569 /// Creates a value with the 1-D vector shape provided in `llvmType`.
570 /// This is used as effective vector length by some intrinsics supporting
571 /// dynamic vector lengths at runtime.
572 static Value createVectorLengthValue(ConversionPatternRewriter &rewriter,
573  Location loc, Type llvmType) {
574  VectorType vType = cast<VectorType>(llvmType);
575  auto vShape = vType.getShape();
576  assert(vShape.size() == 1 && "Unexpected multi-dim vector type");
577 
578  Value baseVecLength = LLVM::ConstantOp::create(
579  rewriter, loc, rewriter.getI32Type(),
580  rewriter.getIntegerAttr(rewriter.getI32Type(), vShape[0]));
581 
582  if (!vType.getScalableDims()[0])
583  return baseVecLength;
584 
585  // For a scalable vector type, create and return `vScale * baseVecLength`.
586  Value vScale = vector::VectorScaleOp::create(rewriter, loc);
587  vScale =
588  arith::IndexCastOp::create(rewriter, loc, rewriter.getI32Type(), vScale);
589  Value scalableVecLength =
590  arith::MulIOp::create(rewriter, loc, baseVecLength, vScale);
591  return scalableVecLength;
592 }
593 
594 /// Helper method to lower a `vector.reduction` op that performs an arithmetic
595 /// operation like add,mul, etc.. `VectorOp` is the LLVM vector intrinsic to use
596 /// and `ScalarOp` is the scalar operation used to add the accumulation value if
597 /// non-null.
598 template <class LLVMRedIntrinOp, class ScalarOp>
599 static Value createIntegerReductionArithmeticOpLowering(
600  ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
601  Value vectorOperand, Value accumulator) {
602 
603  Value result =
604  LLVMRedIntrinOp::create(rewriter, loc, llvmType, vectorOperand);
605 
606  if (accumulator)
607  result = ScalarOp::create(rewriter, loc, accumulator, result);
608  return result;
609 }
610 
611 /// Helper method to lower a `vector.reduction` operation that performs
612 /// a comparison operation like `min`/`max`. `VectorOp` is the LLVM vector
613 /// intrinsic to use and `predicate` is the predicate to use to compare+combine
614 /// the accumulator value if non-null.
615 template <class LLVMRedIntrinOp>
616 static Value createIntegerReductionComparisonOpLowering(
617  ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
618  Value vectorOperand, Value accumulator, LLVM::ICmpPredicate predicate) {
619  Value result =
620  LLVMRedIntrinOp::create(rewriter, loc, llvmType, vectorOperand);
621  if (accumulator) {
622  Value cmp =
623  LLVM::ICmpOp::create(rewriter, loc, predicate, accumulator, result);
624  result = LLVM::SelectOp::create(rewriter, loc, cmp, accumulator, result);
625  }
626  return result;
627 }
628 
namespace {
/// Maps an LLVM vector reduction intrinsic to the scalar LLVM op used to
/// combine the reduced value with an accumulator.
template <typename Source>
struct VectorToScalarMapper;
template <>
struct VectorToScalarMapper<LLVM::vector_reduce_fmaximum> {
  using Type = LLVM::MaximumOp;
};
template <>
struct VectorToScalarMapper<LLVM::vector_reduce_fminimum> {
  using Type = LLVM::MinimumOp;
};
template <>
struct VectorToScalarMapper<LLVM::vector_reduce_fmax> {
  using Type = LLVM::MaxNumOp;
};
template <>
struct VectorToScalarMapper<LLVM::vector_reduce_fmin> {
  using Type = LLVM::MinNumOp;
};
} // namespace
649 
650 template <class LLVMRedIntrinOp>
651 static Value createFPReductionComparisonOpLowering(
652  ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
653  Value vectorOperand, Value accumulator, LLVM::FastmathFlagsAttr fmf) {
654  Value result =
655  LLVMRedIntrinOp::create(rewriter, loc, llvmType, vectorOperand, fmf);
656 
657  if (accumulator) {
658  result = VectorToScalarMapper<LLVMRedIntrinOp>::Type::create(
659  rewriter, loc, result, accumulator);
660  }
661 
662  return result;
663 }
664 
/// Reduction neutral classes for overloading
/// Tag types selecting the mask-neutral constant used to fill masked-off
/// lanes before performing a regular (non-masked) fmaximum/fminimum
/// reduction.
class MaskNeutralFMaximum {};
class MaskNeutralFMinimum {};
668 
669 /// Get the mask neutral floating point maximum value
670 static llvm::APFloat
671 getMaskNeutralValue(MaskNeutralFMaximum,
672  const llvm::fltSemantics &floatSemantics) {
673  return llvm::APFloat::getSmallest(floatSemantics, /*Negative=*/true);
674 }
675 /// Get the mask neutral floating point minimum value
676 static llvm::APFloat
677 getMaskNeutralValue(MaskNeutralFMinimum,
678  const llvm::fltSemantics &floatSemantics) {
679  return llvm::APFloat::getLargest(floatSemantics, /*Negative=*/false);
680 }
681 
682 /// Create the mask neutral floating point MLIR vector constant
683 template <typename MaskNeutral>
684 static Value createMaskNeutralValue(ConversionPatternRewriter &rewriter,
685  Location loc, Type llvmType,
686  Type vectorType) {
687  const auto &floatSemantics = cast<FloatType>(llvmType).getFloatSemantics();
688  auto value = getMaskNeutralValue(MaskNeutral{}, floatSemantics);
689  auto denseValue = DenseElementsAttr::get(cast<ShapedType>(vectorType), value);
690  return LLVM::ConstantOp::create(rewriter, loc, vectorType, denseValue);
691 }
692 
693 /// Lowers masked `fmaximum` and `fminimum` reductions using the non-masked
694 /// intrinsics. It is a workaround to overcome the lack of masked intrinsics for
695 /// `fmaximum`/`fminimum`.
696 /// More information: https://github.com/llvm/llvm-project/issues/64940
697 template <class LLVMRedIntrinOp, class MaskNeutral>
698 static Value
699 lowerMaskedReductionWithRegular(ConversionPatternRewriter &rewriter,
700  Location loc, Type llvmType,
701  Value vectorOperand, Value accumulator,
702  Value mask, LLVM::FastmathFlagsAttr fmf) {
703  const Value vectorMaskNeutral = createMaskNeutralValue<MaskNeutral>(
704  rewriter, loc, llvmType, vectorOperand.getType());
705  const Value selectedVectorByMask = LLVM::SelectOp::create(
706  rewriter, loc, mask, vectorOperand, vectorMaskNeutral);
707  return createFPReductionComparisonOpLowering<LLVMRedIntrinOp>(
708  rewriter, loc, llvmType, selectedVectorByMask, accumulator, fmf);
709 }
710 
711 template <class LLVMRedIntrinOp, class ReductionNeutral>
712 static Value
713 lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, Location loc,
714  Type llvmType, Value vectorOperand,
715  Value accumulator, LLVM::FastmathFlagsAttr fmf) {
716  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
717  llvmType, accumulator);
718  return LLVMRedIntrinOp::create(rewriter, loc, llvmType,
719  /*startValue=*/accumulator, vectorOperand,
720  fmf);
721 }
722 
723 /// Overloaded methods to lower a *predicated* reduction to an llvm intrinsic
724 /// that requires a start value. This start value format spans across fp
725 /// reductions without mask and all the masked reduction intrinsics.
726 template <class LLVMVPRedIntrinOp, class ReductionNeutral>
727 static Value
728 lowerPredicatedReductionWithStartValue(ConversionPatternRewriter &rewriter,
729  Location loc, Type llvmType,
730  Value vectorOperand, Value accumulator) {
731  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
732  llvmType, accumulator);
733  return LLVMVPRedIntrinOp::create(rewriter, loc, llvmType,
734  /*startValue=*/accumulator, vectorOperand);
735 }
736 
737 template <class LLVMVPRedIntrinOp, class ReductionNeutral>
738 static Value lowerPredicatedReductionWithStartValue(
739  ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
740  Value vectorOperand, Value accumulator, Value mask) {
741  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
742  llvmType, accumulator);
743  Value vectorLength =
744  createVectorLengthValue(rewriter, loc, vectorOperand.getType());
745  return LLVMVPRedIntrinOp::create(rewriter, loc, llvmType,
746  /*startValue=*/accumulator, vectorOperand,
747  mask, vectorLength);
748 }
749 
750 template <class LLVMIntVPRedIntrinOp, class IntReductionNeutral,
751  class LLVMFPVPRedIntrinOp, class FPReductionNeutral>
752 static Value lowerPredicatedReductionWithStartValue(
753  ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
754  Value vectorOperand, Value accumulator, Value mask) {
755  if (llvmType.isIntOrIndex())
756  return lowerPredicatedReductionWithStartValue<LLVMIntVPRedIntrinOp,
757  IntReductionNeutral>(
758  rewriter, loc, llvmType, vectorOperand, accumulator, mask);
759 
760  // FP dispatch.
761  return lowerPredicatedReductionWithStartValue<LLVMFPVPRedIntrinOp,
762  FPReductionNeutral>(
763  rewriter, loc, llvmType, vectorOperand, accumulator, mask);
764 }
765 
766 /// Conversion pattern for all vector reductions.
767 class VectorReductionOpConversion
768  : public ConvertOpToLLVMPattern<vector::ReductionOp> {
769 public:
770  explicit VectorReductionOpConversion(const LLVMTypeConverter &typeConv,
771  bool reassociateFPRed)
772  : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
773  reassociateFPReductions(reassociateFPRed) {}
774 
775  LogicalResult
776  matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
777  ConversionPatternRewriter &rewriter) const override {
778  auto kind = reductionOp.getKind();
779  Type eltType = reductionOp.getDest().getType();
780  Type llvmType = typeConverter->convertType(eltType);
781  Value operand = adaptor.getVector();
782  Value acc = adaptor.getAcc();
783  Location loc = reductionOp.getLoc();
784 
785  if (eltType.isIntOrIndex()) {
786  // Integer reductions: add/mul/min/max/and/or/xor.
787  Value result;
788  switch (kind) {
789  case vector::CombiningKind::ADD:
790  result =
791  createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_add,
792  LLVM::AddOp>(
793  rewriter, loc, llvmType, operand, acc);
794  break;
795  case vector::CombiningKind::MUL:
796  result =
797  createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_mul,
798  LLVM::MulOp>(
799  rewriter, loc, llvmType, operand, acc);
800  break;
802  result = createIntegerReductionComparisonOpLowering<
803  LLVM::vector_reduce_umin>(rewriter, loc, llvmType, operand, acc,
804  LLVM::ICmpPredicate::ule);
805  break;
806  case vector::CombiningKind::MINSI:
807  result = createIntegerReductionComparisonOpLowering<
808  LLVM::vector_reduce_smin>(rewriter, loc, llvmType, operand, acc,
809  LLVM::ICmpPredicate::sle);
810  break;
811  case vector::CombiningKind::MAXUI:
812  result = createIntegerReductionComparisonOpLowering<
813  LLVM::vector_reduce_umax>(rewriter, loc, llvmType, operand, acc,
814  LLVM::ICmpPredicate::uge);
815  break;
816  case vector::CombiningKind::MAXSI:
817  result = createIntegerReductionComparisonOpLowering<
818  LLVM::vector_reduce_smax>(rewriter, loc, llvmType, operand, acc,
819  LLVM::ICmpPredicate::sge);
820  break;
821  case vector::CombiningKind::AND:
822  result =
823  createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_and,
824  LLVM::AndOp>(
825  rewriter, loc, llvmType, operand, acc);
826  break;
827  case vector::CombiningKind::OR:
828  result =
829  createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_or,
830  LLVM::OrOp>(
831  rewriter, loc, llvmType, operand, acc);
832  break;
833  case vector::CombiningKind::XOR:
834  result =
835  createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_xor,
836  LLVM::XOrOp>(
837  rewriter, loc, llvmType, operand, acc);
838  break;
839  default:
840  return failure();
841  }
842  rewriter.replaceOp(reductionOp, result);
843 
844  return success();
845  }
846 
847  if (!isa<FloatType>(eltType))
848  return failure();
849 
850  arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
851  LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
852  reductionOp.getContext(),
853  convertArithFastMathFlagsToLLVM(fMFAttr.getValue()));
855  reductionOp.getContext(),
856  fmf.getValue() | (reassociateFPReductions ? LLVM::FastmathFlags::reassoc
857  : LLVM::FastmathFlags::none));
858 
859  // Floating-point reductions: add/mul/min/max
860  Value result;
861  if (kind == vector::CombiningKind::ADD) {
862  result = lowerReductionWithStartValue<LLVM::vector_reduce_fadd,
863  ReductionNeutralZero>(
864  rewriter, loc, llvmType, operand, acc, fmf);
865  } else if (kind == vector::CombiningKind::MUL) {
866  result = lowerReductionWithStartValue<LLVM::vector_reduce_fmul,
867  ReductionNeutralFPOne>(
868  rewriter, loc, llvmType, operand, acc, fmf);
869  } else if (kind == vector::CombiningKind::MINIMUMF) {
870  result =
871  createFPReductionComparisonOpLowering<LLVM::vector_reduce_fminimum>(
872  rewriter, loc, llvmType, operand, acc, fmf);
873  } else if (kind == vector::CombiningKind::MAXIMUMF) {
874  result =
875  createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmaximum>(
876  rewriter, loc, llvmType, operand, acc, fmf);
877  } else if (kind == vector::CombiningKind::MINNUMF) {
878  result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmin>(
879  rewriter, loc, llvmType, operand, acc, fmf);
880  } else if (kind == vector::CombiningKind::MAXNUMF) {
881  result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmax>(
882  rewriter, loc, llvmType, operand, acc, fmf);
883  } else {
884  return failure();
885  }
886 
887  rewriter.replaceOp(reductionOp, result);
888  return success();
889  }
890 
891 private:
892  const bool reassociateFPReductions;
893 };
894 
895 /// Base class to convert a `vector.mask` operation while matching traits
896 /// of the maskable operation nested inside. A `VectorMaskOpConversionBase`
897 /// instance matches against a `vector.mask` operation. The `matchAndRewrite`
898 /// method performs a second match against the maskable operation `MaskedOp`.
899 /// Finally, it invokes the virtual method `matchAndRewriteMaskableOp` to be
900 /// implemented by the concrete conversion classes. This method can match
901 /// against specific traits of the `vector.mask` and the maskable operation. It
902 /// must replace the `vector.mask` operation.
903 template <class MaskedOp>
904 class VectorMaskOpConversionBase
905  : public ConvertOpToLLVMPattern<vector::MaskOp> {
906 public:
908 
909  LogicalResult
910  matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
911  ConversionPatternRewriter &rewriter) const final {
912  // Match against the maskable operation kind.
913  auto maskedOp = llvm::dyn_cast_or_null<MaskedOp>(maskOp.getMaskableOp());
914  if (!maskedOp)
915  return failure();
916  return matchAndRewriteMaskableOp(maskOp, maskedOp, rewriter);
917  }
918 
919 protected:
920  virtual LogicalResult
921  matchAndRewriteMaskableOp(vector::MaskOp maskOp,
922  vector::MaskableOpInterface maskableOp,
923  ConversionPatternRewriter &rewriter) const = 0;
924 };
925 
926 class MaskedReductionOpConversion
927  : public VectorMaskOpConversionBase<vector::ReductionOp> {
928 
929 public:
930  using VectorMaskOpConversionBase<
931  vector::ReductionOp>::VectorMaskOpConversionBase;
932 
933  LogicalResult matchAndRewriteMaskableOp(
934  vector::MaskOp maskOp, MaskableOpInterface maskableOp,
935  ConversionPatternRewriter &rewriter) const override {
936  auto reductionOp = cast<ReductionOp>(maskableOp.getOperation());
937  auto kind = reductionOp.getKind();
938  Type eltType = reductionOp.getDest().getType();
939  Type llvmType = typeConverter->convertType(eltType);
940  Value operand = reductionOp.getVector();
941  Value acc = reductionOp.getAcc();
942  Location loc = reductionOp.getLoc();
943 
944  arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
945  LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
946  reductionOp.getContext(),
947  convertArithFastMathFlagsToLLVM(fMFAttr.getValue()));
948 
949  Value result;
950  switch (kind) {
951  case vector::CombiningKind::ADD:
952  result = lowerPredicatedReductionWithStartValue<
953  LLVM::VPReduceAddOp, ReductionNeutralZero, LLVM::VPReduceFAddOp,
954  ReductionNeutralZero>(rewriter, loc, llvmType, operand, acc,
955  maskOp.getMask());
956  break;
957  case vector::CombiningKind::MUL:
958  result = lowerPredicatedReductionWithStartValue<
959  LLVM::VPReduceMulOp, ReductionNeutralIntOne, LLVM::VPReduceFMulOp,
960  ReductionNeutralFPOne>(rewriter, loc, llvmType, operand, acc,
961  maskOp.getMask());
962  break;
964  result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMinOp,
965  ReductionNeutralUIntMax>(
966  rewriter, loc, llvmType, operand, acc, maskOp.getMask());
967  break;
968  case vector::CombiningKind::MINSI:
969  result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMinOp,
970  ReductionNeutralSIntMax>(
971  rewriter, loc, llvmType, operand, acc, maskOp.getMask());
972  break;
973  case vector::CombiningKind::MAXUI:
974  result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMaxOp,
975  ReductionNeutralUIntMin>(
976  rewriter, loc, llvmType, operand, acc, maskOp.getMask());
977  break;
978  case vector::CombiningKind::MAXSI:
979  result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMaxOp,
980  ReductionNeutralSIntMin>(
981  rewriter, loc, llvmType, operand, acc, maskOp.getMask());
982  break;
983  case vector::CombiningKind::AND:
984  result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceAndOp,
985  ReductionNeutralAllOnes>(
986  rewriter, loc, llvmType, operand, acc, maskOp.getMask());
987  break;
988  case vector::CombiningKind::OR:
989  result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceOrOp,
990  ReductionNeutralZero>(
991  rewriter, loc, llvmType, operand, acc, maskOp.getMask());
992  break;
993  case vector::CombiningKind::XOR:
994  result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceXorOp,
995  ReductionNeutralZero>(
996  rewriter, loc, llvmType, operand, acc, maskOp.getMask());
997  break;
998  case vector::CombiningKind::MINNUMF:
999  result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMinOp,
1000  ReductionNeutralFPMax>(
1001  rewriter, loc, llvmType, operand, acc, maskOp.getMask());
1002  break;
1003  case vector::CombiningKind::MAXNUMF:
1004  result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMaxOp,
1005  ReductionNeutralFPMin>(
1006  rewriter, loc, llvmType, operand, acc, maskOp.getMask());
1007  break;
1008  case CombiningKind::MAXIMUMF:
1009  result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fmaximum,
1010  MaskNeutralFMaximum>(
1011  rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
1012  break;
1013  case CombiningKind::MINIMUMF:
1014  result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fminimum,
1015  MaskNeutralFMinimum>(
1016  rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
1017  break;
1018  }
1019 
1020  // Replace `vector.mask` operation altogether.
1021  rewriter.replaceOp(maskOp, result);
1022  return success();
1023  }
1024 };
1025 
1026 class VectorShuffleOpConversion
1027  : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
1028 public:
1030 
1031  LogicalResult
1032  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
1033  ConversionPatternRewriter &rewriter) const override {
1034  auto loc = shuffleOp->getLoc();
1035  auto v1Type = shuffleOp.getV1VectorType();
1036  auto v2Type = shuffleOp.getV2VectorType();
1037  auto vectorType = shuffleOp.getResultVectorType();
1038  Type llvmType = typeConverter->convertType(vectorType);
1039  ArrayRef<int64_t> mask = shuffleOp.getMask();
1040 
1041  // Bail if result type cannot be lowered.
1042  if (!llvmType)
1043  return failure();
1044 
1045  // Get rank and dimension sizes.
1046  int64_t rank = vectorType.getRank();
1047 #ifndef NDEBUG
1048  bool wellFormed0DCase =
1049  v1Type.getRank() == 0 && v2Type.getRank() == 0 && rank == 1;
1050  bool wellFormedNDCase =
1051  v1Type.getRank() == rank && v2Type.getRank() == rank;
1052  assert((wellFormed0DCase || wellFormedNDCase) && "op is not well-formed");
1053 #endif
1054 
1055  // For rank 0 and 1, where both operands have *exactly* the same vector
1056  // type, there is direct shuffle support in LLVM. Use it!
1057  if (rank <= 1 && v1Type == v2Type) {
1058  Value llvmShuffleOp = LLVM::ShuffleVectorOp::create(
1059  rewriter, loc, adaptor.getV1(), adaptor.getV2(),
1060  llvm::to_vector_of<int32_t>(mask));
1061  rewriter.replaceOp(shuffleOp, llvmShuffleOp);
1062  return success();
1063  }
1064 
1065  // For all other cases, insert the individual values individually.
1066  int64_t v1Dim = v1Type.getDimSize(0);
1067  Type eltType;
1068  if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(llvmType))
1069  eltType = arrayType.getElementType();
1070  else
1071  eltType = cast<VectorType>(llvmType).getElementType();
1072  Value insert = LLVM::PoisonOp::create(rewriter, loc, llvmType);
1073  int64_t insPos = 0;
1074  for (int64_t extPos : mask) {
1075  Value value = adaptor.getV1();
1076  if (extPos >= v1Dim) {
1077  extPos -= v1Dim;
1078  value = adaptor.getV2();
1079  }
1080  Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
1081  eltType, rank, extPos);
1082  insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
1083  llvmType, rank, insPos++);
1084  }
1085  rewriter.replaceOp(shuffleOp, insert);
1086  return success();
1087  }
1088 };
1089 
1090 class VectorExtractOpConversion
1091  : public ConvertOpToLLVMPattern<vector::ExtractOp> {
1092 public:
1094 
1095  LogicalResult
1096  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
1097  ConversionPatternRewriter &rewriter) const override {
1098  auto loc = extractOp->getLoc();
1099  auto resultType = extractOp.getResult().getType();
1100  auto llvmResultType = typeConverter->convertType(resultType);
1101  // Bail if result type cannot be lowered.
1102  if (!llvmResultType)
1103  return failure();
1104 
1106  adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
1107 
1108  // The Vector -> LLVM lowering models N-D vectors as nested aggregates of
1109  // 1-d vectors. This nesting is modeled using arrays. We do this conversion
1110  // from a N-d vector extract to a nested aggregate vector extract in two
1111  // steps:
1112  // - Extract a member from the nested aggregate. The result can be
1113  // a lower rank nested aggregate or a vector (1-D). This is done using
1114  // `llvm.extractvalue`.
1115  // - Extract a scalar out of the vector if needed. This is done using
1116  // `llvm.extractelement`.
1117 
1118  // Determine if we need to extract a member out of the aggregate. We
1119  // always need to extract a member if the input rank >= 2.
1120  bool extractsAggregate = extractOp.getSourceVectorType().getRank() >= 2;
1121  // Determine if we need to extract a scalar as the result. We extract
1122  // a scalar if the extract is full rank, i.e., the number of indices is
1123  // equal to source vector rank.
1124  bool extractsScalar = static_cast<int64_t>(positionVec.size()) ==
1125  extractOp.getSourceVectorType().getRank();
1126 
1127  // Since the LLVM type converter converts 0-d vectors to 1-d vectors, we
1128  // need to add a position for this change.
1129  if (extractOp.getSourceVectorType().getRank() == 0) {
1130  Type idxType = typeConverter->convertType(rewriter.getIndexType());
1131  positionVec.push_back(rewriter.getZeroAttr(idxType));
1132  }
1133 
1134  Value extracted = adaptor.getSource();
1135  if (extractsAggregate) {
1136  ArrayRef<OpFoldResult> position(positionVec);
1137  if (extractsScalar) {
1138  // If we are extracting a scalar from the extracted member, we drop
1139  // the last index, which will be used to extract the scalar out of the
1140  // vector.
1141  position = position.drop_back();
1142  }
1143  // llvm.extractvalue does not support dynamic dimensions.
1144  if (!llvm::all_of(position, llvm::IsaPred<Attribute>)) {
1145  return failure();
1146  }
1147  extracted = LLVM::ExtractValueOp::create(rewriter, loc, extracted,
1148  getAsIntegers(position));
1149  }
1150 
1151  if (extractsScalar) {
1152  extracted = LLVM::ExtractElementOp::create(
1153  rewriter, loc, extracted,
1154  getAsLLVMValue(rewriter, loc, positionVec.back()));
1155  }
1156 
1157  rewriter.replaceOp(extractOp, extracted);
1158  return success();
1159  }
1160 };
1161 
1162 /// Conversion pattern that turns a vector.fma on a 1-D vector
1163 /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
1164 /// This does not match vectors of n >= 2 rank.
1165 ///
1166 /// Example:
1167 /// ```
1168 /// vector.fma %a, %a, %a : vector<8xf32>
1169 /// ```
1170 /// is converted to:
1171 /// ```
1172 /// llvm.intr.fmuladd %va, %va, %va:
1173 /// (!llvm."<8 x f32>">, !llvm<"<8 x f32>">, !llvm<"<8 x f32>">)
1174 /// -> !llvm."<8 x f32>">
1175 /// ```
1176 class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
1177 public:
1179 
1180  LogicalResult
1181  matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
1182  ConversionPatternRewriter &rewriter) const override {
1183  VectorType vType = fmaOp.getVectorType();
1184  if (vType.getRank() > 1)
1185  return failure();
1186 
1187  rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(
1188  fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
1189  return success();
1190  }
1191 };
1192 
1193 class VectorInsertOpConversion
1194  : public ConvertOpToLLVMPattern<vector::InsertOp> {
1195 public:
1197 
1198  LogicalResult
1199  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
1200  ConversionPatternRewriter &rewriter) const override {
1201  auto loc = insertOp->getLoc();
1202  auto destVectorType = insertOp.getDestVectorType();
1203  auto llvmResultType = typeConverter->convertType(destVectorType);
1204  // Bail if result type cannot be lowered.
1205  if (!llvmResultType)
1206  return failure();
1207 
1209  adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
1210 
1211  // The logic in this pattern mirrors VectorExtractOpConversion. Refer to
1212  // its explanatory comment about how N-D vectors are converted as nested
1213  // aggregates (llvm.array's) of 1D vectors.
1214  //
1215  // The innermost dimension of the destination vector, when converted to a
1216  // nested aggregate form, will always be a 1D vector.
1217  //
1218  // * If the insertion is happening into the innermost dimension of the
1219  // destination vector:
1220  // - If the destination is a nested aggregate, extract a 1D vector out of
1221  // the aggregate. This can be done using llvm.extractvalue. The
1222  // destination is now guaranteed to be a 1D vector, to which we are
1223  // inserting.
1224  // - Do the insertion into the 1D destination vector, and make the result
1225  // the new source nested aggregate. This can be done using
1226  // llvm.insertelement.
1227  // * Insert the source nested aggregate into the destination nested
1228  // aggregate.
1229 
1230  // Determine if we need to extract/insert a 1D vector out of the aggregate.
1231  bool isNestedAggregate = isa<LLVM::LLVMArrayType>(llvmResultType);
1232  // Determine if we need to insert a scalar into the 1D vector.
1233  bool insertIntoInnermostDim =
1234  static_cast<int64_t>(positionVec.size()) == destVectorType.getRank();
1235 
1236  ArrayRef<OpFoldResult> positionOf1DVectorWithinAggregate(
1237  positionVec.begin(),
1238  insertIntoInnermostDim ? positionVec.size() - 1 : positionVec.size());
1239  OpFoldResult positionOfScalarWithin1DVector;
1240  if (destVectorType.getRank() == 0) {
1241  // Since the LLVM type converter converts 0D vectors to 1D vectors, we
1242  // need to create a 0 here as the position into the 1D vector.
1243  Type idxType = typeConverter->convertType(rewriter.getIndexType());
1244  positionOfScalarWithin1DVector = rewriter.getZeroAttr(idxType);
1245  } else if (insertIntoInnermostDim) {
1246  positionOfScalarWithin1DVector = positionVec.back();
1247  }
1248 
1249  // We are going to mutate this 1D vector until it is either the final
1250  // result (in the non-aggregate case) or the value that needs to be
1251  // inserted into the aggregate result.
1252  Value sourceAggregate = adaptor.getValueToStore();
1253  if (insertIntoInnermostDim) {
1254  // Scalar-into-1D-vector case, so we know we will have to create a
1255  // InsertElementOp. The question is into what destination.
1256  if (isNestedAggregate) {
1257  // Aggregate case: the destination for the InsertElementOp needs to be
1258  // extracted from the aggregate.
1259  if (!llvm::all_of(positionOf1DVectorWithinAggregate,
1260  llvm::IsaPred<Attribute>)) {
1261  // llvm.extractvalue does not support dynamic dimensions.
1262  return failure();
1263  }
1264  sourceAggregate = LLVM::ExtractValueOp::create(
1265  rewriter, loc, adaptor.getDest(),
1266  getAsIntegers(positionOf1DVectorWithinAggregate));
1267  } else {
1268  // No-aggregate case. The destination for the InsertElementOp is just
1269  // the insertOp's destination.
1270  sourceAggregate = adaptor.getDest();
1271  }
1272  // Insert the scalar into the 1D vector.
1273  sourceAggregate = LLVM::InsertElementOp::create(
1274  rewriter, loc, sourceAggregate.getType(), sourceAggregate,
1275  adaptor.getValueToStore(),
1276  getAsLLVMValue(rewriter, loc, positionOfScalarWithin1DVector));
1277  }
1278 
1279  Value result = sourceAggregate;
1280  if (isNestedAggregate) {
1281  result = LLVM::InsertValueOp::create(
1282  rewriter, loc, adaptor.getDest(), sourceAggregate,
1283  getAsIntegers(positionOf1DVectorWithinAggregate));
1284  }
1285 
1286  rewriter.replaceOp(insertOp, result);
1287  return success();
1288  }
1289 };
1290 
1291 /// Lower vector.scalable.insert ops to LLVM vector.insert
1292 struct VectorScalableInsertOpLowering
1293  : public ConvertOpToLLVMPattern<vector::ScalableInsertOp> {
1294  using ConvertOpToLLVMPattern<
1295  vector::ScalableInsertOp>::ConvertOpToLLVMPattern;
1296 
1297  LogicalResult
1298  matchAndRewrite(vector::ScalableInsertOp insOp, OpAdaptor adaptor,
1299  ConversionPatternRewriter &rewriter) const override {
1300  rewriter.replaceOpWithNewOp<LLVM::vector_insert>(
1301  insOp, adaptor.getDest(), adaptor.getValueToStore(), adaptor.getPos());
1302  return success();
1303  }
1304 };
1305 
1306 /// Lower vector.scalable.extract ops to LLVM vector.extract
1307 struct VectorScalableExtractOpLowering
1308  : public ConvertOpToLLVMPattern<vector::ScalableExtractOp> {
1309  using ConvertOpToLLVMPattern<
1310  vector::ScalableExtractOp>::ConvertOpToLLVMPattern;
1311 
1312  LogicalResult
1313  matchAndRewrite(vector::ScalableExtractOp extOp, OpAdaptor adaptor,
1314  ConversionPatternRewriter &rewriter) const override {
1315  rewriter.replaceOpWithNewOp<LLVM::vector_extract>(
1316  extOp, typeConverter->convertType(extOp.getResultVectorType()),
1317  adaptor.getSource(), adaptor.getPos());
1318  return success();
1319  }
1320 };
1321 
1322 /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
1323 ///
1324 /// Example:
1325 /// ```
1326 /// %d = vector.fma %a, %b, %c : vector<2x4xf32>
1327 /// ```
1328 /// is rewritten into:
1329 /// ```
1330 /// %r = vector.broadcast %f0 : f32 to vector<2x4xf32>
1331 /// %va = vector.extractvalue %a[0] : vector<2x4xf32>
1332 /// %vb = vector.extractvalue %b[0] : vector<2x4xf32>
1333 /// %vc = vector.extractvalue %c[0] : vector<2x4xf32>
1334 /// %vd = vector.fma %va, %vb, %vc : vector<4xf32>
1335 /// %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
1336 /// %va2 = vector.extractvalue %a2[1] : vector<2x4xf32>
1337 /// %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32>
1338 /// %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32>
1339 /// %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
1340 /// %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
1341 /// // %r3 holds the final value.
1342 /// ```
1343 class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
1344 public:
1345  using Base::Base;
1346 
1347  void initialize() {
1348  // This pattern recursively unpacks one dimension at a time. The recursion
1349  // bounded as the rank is strictly decreasing.
1350  setHasBoundedRewriteRecursion();
1351  }
1352 
1353  LogicalResult matchAndRewrite(FMAOp op,
1354  PatternRewriter &rewriter) const override {
1355  auto vType = op.getVectorType();
1356  if (vType.getRank() < 2)
1357  return failure();
1358 
1359  auto loc = op.getLoc();
1360  auto elemType = vType.getElementType();
1361  Value zero = arith::ConstantOp::create(rewriter, loc, elemType,
1362  rewriter.getZeroAttr(elemType));
1363  Value desc = vector::BroadcastOp::create(rewriter, loc, vType, zero);
1364  for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
1365  Value extrLHS = ExtractOp::create(rewriter, loc, op.getLhs(), i);
1366  Value extrRHS = ExtractOp::create(rewriter, loc, op.getRhs(), i);
1367  Value extrACC = ExtractOp::create(rewriter, loc, op.getAcc(), i);
1368  Value fma = FMAOp::create(rewriter, loc, extrLHS, extrRHS, extrACC);
1369  desc = InsertOp::create(rewriter, loc, fma, desc, i);
1370  }
1371  rewriter.replaceOp(op, desc);
1372  return success();
1373  }
1374 };
1375 
1376 /// Returns the strides if the memory underlying `memRefType` has a contiguous
1377 /// static layout.
1378 static std::optional<SmallVector<int64_t, 4>>
1379 computeContiguousStrides(MemRefType memRefType) {
1380  int64_t offset;
1381  SmallVector<int64_t, 4> strides;
1382  if (failed(memRefType.getStridesAndOffset(strides, offset)))
1383  return std::nullopt;
1384  if (!strides.empty() && strides.back() != 1)
1385  return std::nullopt;
1386  // If no layout or identity layout, this is contiguous by definition.
1387  if (memRefType.getLayout().isIdentity())
1388  return strides;
1389 
1390  // Otherwise, we must determine contiguity form shapes. This can only ever
1391  // work in static cases because MemRefType is underspecified to represent
1392  // contiguous dynamic shapes in other ways than with just empty/identity
1393  // layout.
1394  auto sizes = memRefType.getShape();
1395  for (int index = 0, e = strides.size() - 1; index < e; ++index) {
1396  if (ShapedType::isDynamic(sizes[index + 1]) ||
1397  ShapedType::isDynamic(strides[index]) ||
1398  ShapedType::isDynamic(strides[index + 1]))
1399  return std::nullopt;
1400  if (strides[index] != strides[index + 1] * sizes[index + 1])
1401  return std::nullopt;
1402  }
1403  return strides;
1404 }
1405 
1406 class VectorTypeCastOpConversion
1407  : public ConvertOpToLLVMPattern<vector::TypeCastOp> {
1408 public:
1410 
1411  LogicalResult
1412  matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor,
1413  ConversionPatternRewriter &rewriter) const override {
1414  auto loc = castOp->getLoc();
1415  MemRefType sourceMemRefType =
1416  cast<MemRefType>(castOp.getOperand().getType());
1417  MemRefType targetMemRefType = castOp.getType();
1418 
1419  // Only static shape casts supported atm.
1420  if (!sourceMemRefType.hasStaticShape() ||
1421  !targetMemRefType.hasStaticShape())
1422  return failure();
1423 
1424  auto llvmSourceDescriptorTy =
1425  dyn_cast<LLVM::LLVMStructType>(adaptor.getOperands()[0].getType());
1426  if (!llvmSourceDescriptorTy)
1427  return failure();
1428  MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]);
1429 
1430  auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
1431  typeConverter->convertType(targetMemRefType));
1432  if (!llvmTargetDescriptorTy)
1433  return failure();
1434 
1435  // Only contiguous source buffers supported atm.
1436  auto sourceStrides = computeContiguousStrides(sourceMemRefType);
1437  if (!sourceStrides)
1438  return failure();
1439  auto targetStrides = computeContiguousStrides(targetMemRefType);
1440  if (!targetStrides)
1441  return failure();
1442  // Only support static strides for now, regardless of contiguity.
1443  if (llvm::any_of(*targetStrides, ShapedType::isDynamic))
1444  return failure();
1445 
1446  auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
1447 
1448  // Create descriptor.
1449  auto desc = MemRefDescriptor::poison(rewriter, loc, llvmTargetDescriptorTy);
1450  // Set allocated ptr.
1451  Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
1452  desc.setAllocatedPtr(rewriter, loc, allocated);
1453 
1454  // Set aligned ptr.
1455  Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
1456  desc.setAlignedPtr(rewriter, loc, ptr);
1457  // Fill offset 0.
1458  auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
1459  auto zero = LLVM::ConstantOp::create(rewriter, loc, int64Ty, attr);
1460  desc.setOffset(rewriter, loc, zero);
1461 
1462  // Fill size and stride descriptors in memref.
1463  for (const auto &indexedSize :
1464  llvm::enumerate(targetMemRefType.getShape())) {
1465  int64_t index = indexedSize.index();
1466  auto sizeAttr =
1467  rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
1468  auto size = LLVM::ConstantOp::create(rewriter, loc, int64Ty, sizeAttr);
1469  desc.setSize(rewriter, loc, index, size);
1470  auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
1471  (*targetStrides)[index]);
1472  auto stride =
1473  LLVM::ConstantOp::create(rewriter, loc, int64Ty, strideAttr);
1474  desc.setStride(rewriter, loc, index, stride);
1475  }
1476 
1477  rewriter.replaceOp(castOp, {desc});
1478  return success();
1479  }
1480 };
1481 
1482 /// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only).
1483 /// Non-scalable versions of this operation are handled in Vector Transforms.
1484 class VectorCreateMaskOpConversion
1485  : public OpConversionPattern<vector::CreateMaskOp> {
1486 public:
1487  explicit VectorCreateMaskOpConversion(MLIRContext *context,
1488  bool enableIndexOpt)
1489  : OpConversionPattern<vector::CreateMaskOp>(context),
1490  force32BitVectorIndices(enableIndexOpt) {}
1491 
1492  LogicalResult
1493  matchAndRewrite(vector::CreateMaskOp op, OpAdaptor adaptor,
1494  ConversionPatternRewriter &rewriter) const override {
1495  auto dstType = op.getType();
1496  if (dstType.getRank() != 1 || !cast<VectorType>(dstType).isScalable())
1497  return failure();
1498  IntegerType idxType =
1499  force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
1500  auto loc = op->getLoc();
1501  Value indices = LLVM::StepVectorOp::create(
1502  rewriter, loc,
1503  LLVM::getVectorType(idxType, dstType.getShape()[0],
1504  /*isScalable=*/true));
1505  auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType,
1506  adaptor.getOperands()[0]);
1507  Value bounds = BroadcastOp::create(rewriter, loc, indices.getType(), bound);
1508  Value comp = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
1509  indices, bounds);
1510  rewriter.replaceOp(op, comp);
1511  return success();
1512  }
1513 
1514 private:
1515  const bool force32BitVectorIndices;
1516 };
1517 
class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
  // Optional symbol-table cache handed to the LLVM::lookupOrCreatePrint*Fn
  // helpers below; may be null.
  SymbolTableCollection *symbolTables = nullptr;

public:
  explicit VectorPrintOpConversion(
      const LLVMTypeConverter &typeConverter,
      SymbolTableCollection *symbolTables = nullptr)
      : ConvertOpToLLVMPattern<vector::PrintOp>(typeConverter),
        symbolTables(symbolTables) {}
1527 
1528  // Lowering implementation that relies on a small runtime support library,
1529  // which only needs to provide a few printing methods (single value for all
1530  // data types, opening/closing bracket, comma, newline). The lowering splits
1531  // the vector into elementary printing operations. The advantage of this
1532  // approach is that the library can remain unaware of all low-level
1533  // implementation details of vectors while still supporting output of any
1534  // shaped and dimensioned vector.
1535  //
1536  // Note: This lowering only handles scalars, n-D vectors are broken into
1537  // printing scalars in loops in VectorToSCF.
1538  //
1539  // TODO: rely solely on libc in future? something else?
1540  //
1541  LogicalResult
1542  matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor,
1543  ConversionPatternRewriter &rewriter) const override {
1544  auto parent = printOp->getParentOfType<ModuleOp>();
1545  if (!parent)
1546  return failure();
1547 
1548  auto loc = printOp->getLoc();
1549 
1550  if (auto value = adaptor.getSource()) {
1551  Type printType = printOp.getPrintType();
1552  if (isa<VectorType>(printType)) {
1553  // Vectors should be broken into elementary print ops in VectorToSCF.
1554  return failure();
1555  }
1556  if (failed(emitScalarPrint(rewriter, parent, loc, printType, value)))
1557  return failure();
1558  }
1559 
1560  auto punct = printOp.getPunctuation();
1561  if (auto stringLiteral = printOp.getStringLiteral()) {
1562  auto createResult =
1563  LLVM::createPrintStrCall(rewriter, loc, parent, "vector_print_str",
1564  *stringLiteral, *getTypeConverter(),
1565  /*addNewline=*/false);
1566  if (createResult.failed())
1567  return failure();
1568 
1569  } else if (punct != PrintPunctuation::NoPunctuation) {
1570  FailureOr<LLVM::LLVMFuncOp> op = [&]() {
1571  switch (punct) {
1572  case PrintPunctuation::Close:
1573  return LLVM::lookupOrCreatePrintCloseFn(rewriter, parent,
1574  symbolTables);
1575  case PrintPunctuation::Open:
1576  return LLVM::lookupOrCreatePrintOpenFn(rewriter, parent,
1577  symbolTables);
1578  case PrintPunctuation::Comma:
1579  return LLVM::lookupOrCreatePrintCommaFn(rewriter, parent,
1580  symbolTables);
1581  case PrintPunctuation::NewLine:
1582  return LLVM::lookupOrCreatePrintNewlineFn(rewriter, parent,
1583  symbolTables);
1584  default:
1585  llvm_unreachable("unexpected punctuation");
1586  }
1587  }();
1588  if (failed(op))
1589  return failure();
1590  emitCall(rewriter, printOp->getLoc(), op.value());
1591  }
1592 
1593  rewriter.eraseOp(printOp);
1594  return success();
1595  }
1596 
private:
  // Adjustment applied to the scalar operand before it is passed to the
  // runtime print function (see emitScalarPrint below for where each
  // enumerator is selected).
  enum class PrintConversion {
    // clang-format off
    None,       // pass the value through unchanged
    ZeroExt64,  // zero-extend to i64 (unsigned ints and i1 booleans)
    SignExt64,  // sign-extend to i64 (signed/signless ints narrower than 64)
    Bitcast16   // bitcast f16/bf16 to its 16 raw bits
    // clang-format on
  };
1606 
  /// Emits a call to the runtime print function matching `printType`,
  /// inserting a widening/bitcast conversion on `value` first when the
  /// runtime has no direct support for that type. Fails when `printType`
  /// cannot be converted or has no runtime printing support (e.g. integers
  /// wider than 64 bits).
  LogicalResult emitScalarPrint(ConversionPatternRewriter &rewriter,
                                ModuleOp parent, Location loc, Type printType,
                                Value value) const {
    if (typeConverter->convertType(printType) == nullptr)
      return failure();

    // Make sure element type has runtime support.
    PrintConversion conversion = PrintConversion::None;
    FailureOr<Operation *> printer;
    if (printType.isF32()) {
      printer = LLVM::lookupOrCreatePrintF32Fn(rewriter, parent, symbolTables);
    } else if (printType.isF64()) {
      printer = LLVM::lookupOrCreatePrintF64Fn(rewriter, parent, symbolTables);
    } else if (printType.isF16()) {
      // f16 is passed to the runtime as its raw 16 bits.
      conversion = PrintConversion::Bitcast16; // bits!
      printer = LLVM::lookupOrCreatePrintF16Fn(rewriter, parent, symbolTables);
    } else if (printType.isBF16()) {
      // bf16 is passed to the runtime as its raw 16 bits.
      conversion = PrintConversion::Bitcast16; // bits!
      printer = LLVM::lookupOrCreatePrintBF16Fn(rewriter, parent, symbolTables);
    } else if (printType.isIndex()) {
      printer = LLVM::lookupOrCreatePrintU64Fn(rewriter, parent, symbolTables);
    } else if (auto intTy = dyn_cast<IntegerType>(printType)) {
      // Integers need a zero or sign extension on the operand
      // (depending on the source type) as well as a signed or
      // unsigned print method. Up to 64-bit is supported.
      unsigned width = intTy.getWidth();
      if (intTy.isUnsigned()) {
        if (width <= 64) {
          if (width < 64)
            conversion = PrintConversion::ZeroExt64;
          printer =
              LLVM::lookupOrCreatePrintU64Fn(rewriter, parent, symbolTables);
        } else {
          // No runtime support for unsigned integers wider than 64 bits.
          return failure();
        }
      } else {
        assert(intTy.isSignless() || intTy.isSigned());
        if (width <= 64) {
          // Note that we *always* zero extend booleans (1-bit integers),
          // so that true/false is printed as 1/0 rather than -1/0.
          if (width == 1)
            conversion = PrintConversion::ZeroExt64;
          else if (width < 64)
            conversion = PrintConversion::SignExt64;
          printer =
              LLVM::lookupOrCreatePrintI64Fn(rewriter, parent, symbolTables);
        } else {
          // No runtime support for signed integers wider than 64 bits.
          return failure();
        }
      }
    } else {
      // Unsupported element type.
      return failure();
    }
    if (failed(printer))
      return failure();

    // Apply the conversion selected above before calling the printer.
    switch (conversion) {
    case PrintConversion::ZeroExt64:
      value = arith::ExtUIOp::create(
          rewriter, loc, IntegerType::get(rewriter.getContext(), 64), value);
      break;
    case PrintConversion::SignExt64:
      value = arith::ExtSIOp::create(
          rewriter, loc, IntegerType::get(rewriter.getContext(), 64), value);
      break;
    case PrintConversion::Bitcast16:
      value = LLVM::BitcastOp::create(
          rewriter, loc, IntegerType::get(rewriter.getContext(), 16), value);
      break;
    case PrintConversion::None:
      break;
    }
    emitCall(rewriter, loc, printer.value(), value);
    return success();
  }
1682 
1683  // Helper to emit a call.
1684  static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
1685  Operation *ref, ValueRange params = ValueRange()) {
1686  LLVM::CallOp::create(rewriter, loc, TypeRange(), SymbolRefAttr::get(ref),
1687  params);
1688  }
1689 };
1690 
1691 /// A broadcast of a scalar is lowered to an insertelement + a shufflevector
1692 /// operation. Only broadcasts to 0-d and 1-d vectors are lowered by this
1693 /// pattern, the higher rank cases are handled by another pattern.
1694 struct VectorBroadcastScalarToLowRankLowering
1695  : public ConvertOpToLLVMPattern<vector::BroadcastOp> {
1697 
1698  LogicalResult
1699  matchAndRewrite(vector::BroadcastOp broadcast, OpAdaptor adaptor,
1700  ConversionPatternRewriter &rewriter) const override {
1701  if (isa<VectorType>(broadcast.getSourceType()))
1702  return rewriter.notifyMatchFailure(
1703  broadcast, "broadcast from vector type not handled");
1704 
1705  VectorType resultType = broadcast.getType();
1706  if (resultType.getRank() > 1)
1707  return rewriter.notifyMatchFailure(broadcast,
1708  "broadcast to 2+-d handled elsewhere");
1709 
1710  // First insert it into a poison vector so we can shuffle it.
1711  auto vectorType = typeConverter->convertType(broadcast.getType());
1712  Value poison =
1713  LLVM::PoisonOp::create(rewriter, broadcast.getLoc(), vectorType);
1714  auto zero = LLVM::ConstantOp::create(
1715  rewriter, broadcast.getLoc(),
1716  typeConverter->convertType(rewriter.getIntegerType(32)),
1717  rewriter.getZeroAttr(rewriter.getIntegerType(32)));
1718 
1719  // For 0-d vector, we simply do `insertelement`.
1720  if (resultType.getRank() == 0) {
1721  rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
1722  broadcast, vectorType, poison, adaptor.getSource(), zero);
1723  return success();
1724  }
1725 
1726  // For 1-d vector, we additionally do a `vectorshuffle`.
1727  auto v =
1728  LLVM::InsertElementOp::create(rewriter, broadcast.getLoc(), vectorType,
1729  poison, adaptor.getSource(), zero);
1730 
1731  int64_t width = cast<VectorType>(broadcast.getType()).getDimSize(0);
1732  SmallVector<int32_t> zeroValues(width, 0);
1733 
1734  // Shuffle the value across the desired number of elements.
1735  rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(broadcast, v, poison,
1736  zeroValues);
1737  return success();
1738  }
1739 };
1740 
1741 /// The broadcast of a scalar is lowered to an insertelement + a shufflevector
1742 /// operation. Only broadcasts to 2+-d vector result types are lowered by this
1743 /// pattern, the 1-d case is handled by another pattern. Broadcasts from vectors
1744 /// are not converted to LLVM, only broadcasts from scalars are.
1745 struct VectorBroadcastScalarToNdLowering
1746  : public ConvertOpToLLVMPattern<BroadcastOp> {
1748 
1749  LogicalResult
1750  matchAndRewrite(BroadcastOp broadcast, OpAdaptor adaptor,
1751  ConversionPatternRewriter &rewriter) const override {
1752  if (isa<VectorType>(broadcast.getSourceType()))
1753  return rewriter.notifyMatchFailure(
1754  broadcast, "broadcast from vector type not handled");
1755 
1756  VectorType resultType = broadcast.getType();
1757  if (resultType.getRank() <= 1)
1758  return rewriter.notifyMatchFailure(
1759  broadcast, "broadcast to 1-d or 0-d handled elsewhere");
1760 
1761  // First insert it into an undef vector so we can shuffle it.
1762  auto loc = broadcast.getLoc();
1763  auto vectorTypeInfo =
1764  LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter());
1765  auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
1766  auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
1767  if (!llvmNDVectorTy || !llvm1DVectorTy)
1768  return failure();
1769 
1770  // Construct returned value.
1771  Value desc = LLVM::PoisonOp::create(rewriter, loc, llvmNDVectorTy);
1772 
1773  // Construct a 1-D vector with the broadcasted value that we insert in all
1774  // the places within the returned descriptor.
1775  Value vdesc = LLVM::PoisonOp::create(rewriter, loc, llvm1DVectorTy);
1776  auto zero = LLVM::ConstantOp::create(
1777  rewriter, loc, typeConverter->convertType(rewriter.getIntegerType(32)),
1778  rewriter.getZeroAttr(rewriter.getIntegerType(32)));
1779  Value v = LLVM::InsertElementOp::create(rewriter, loc, llvm1DVectorTy,
1780  vdesc, adaptor.getSource(), zero);
1781 
1782  // Shuffle the value across the desired number of elements.
1783  int64_t width = resultType.getDimSize(resultType.getRank() - 1);
1784  SmallVector<int32_t> zeroValues(width, 0);
1785  v = LLVM::ShuffleVectorOp::create(rewriter, loc, v, v, zeroValues);
1786 
1787  // Iterate of linear index, convert to coords space and insert broadcasted
1788  // 1-D vector in each position.
1789  nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
1790  desc = LLVM::InsertValueOp::create(rewriter, loc, desc, v, position);
1791  });
1792  rewriter.replaceOp(broadcast, desc);
1793  return success();
1794  }
1795 };
1796 
1797 /// Conversion pattern for a `vector.interleave`.
1798 /// This supports fixed-sized vectors and scalable vectors.
1799 struct VectorInterleaveOpLowering
1800  : public ConvertOpToLLVMPattern<vector::InterleaveOp> {
1802 
1803  LogicalResult
1804  matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
1805  ConversionPatternRewriter &rewriter) const override {
1806  VectorType resultType = interleaveOp.getResultVectorType();
1807  // n-D interleaves should have been lowered already.
1808  if (resultType.getRank() != 1)
1809  return rewriter.notifyMatchFailure(interleaveOp,
1810  "InterleaveOp not rank 1");
1811  // If the result is rank 1, then this directly maps to LLVM.
1812  if (resultType.isScalable()) {
1813  rewriter.replaceOpWithNewOp<LLVM::vector_interleave2>(
1814  interleaveOp, typeConverter->convertType(resultType),
1815  adaptor.getLhs(), adaptor.getRhs());
1816  return success();
1817  }
1818  // Lower fixed-size interleaves to a shufflevector. While the
1819  // vector.interleave2 intrinsic supports fixed and scalable vectors, the
1820  // langref still recommends fixed-vectors use shufflevector, see:
1821  // https://llvm.org/docs/LangRef.html#id876.
1822  int64_t resultVectorSize = resultType.getNumElements();
1823  SmallVector<int32_t> interleaveShuffleMask;
1824  interleaveShuffleMask.reserve(resultVectorSize);
1825  for (int i = 0, end = resultVectorSize / 2; i < end; ++i) {
1826  interleaveShuffleMask.push_back(i);
1827  interleaveShuffleMask.push_back((resultVectorSize / 2) + i);
1828  }
1829  rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(
1830  interleaveOp, adaptor.getLhs(), adaptor.getRhs(),
1831  interleaveShuffleMask);
1832  return success();
1833  }
1834 };
1835 
/// Conversion pattern for a `vector.deinterleave`.
/// This supports fixed-sized vectors and scalable vectors.
struct VectorDeinterleaveOpLowering
    : public ConvertOpToLLVMPattern<vector::DeinterleaveOp> {

  LogicalResult
  matchAndRewrite(vector::DeinterleaveOp deinterleaveOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType resultType = deinterleaveOp.getResultVectorType();
    VectorType sourceType = deinterleaveOp.getSourceVectorType();
    auto loc = deinterleaveOp.getLoc();

    // Note: n-D deinterleave operations should be lowered to the 1-D before
    // converting to LLVM.
    if (resultType.getRank() != 1)
      return rewriter.notifyMatchFailure(deinterleaveOp,
                                         "DeinterleaveOp not rank 1");

    if (resultType.isScalable()) {
      // Scalable vectors cannot use a static shuffle mask; call the
      // llvm.vector.deinterleave2 intrinsic instead, whose results are packed
      // into a single struct-like value by the type converter.
      const auto *llvmTypeConverter = this->getTypeConverter();
      auto deinterleaveResults = deinterleaveOp.getResultTypes();
      auto packedOpResults =
          llvmTypeConverter->packOperationResults(deinterleaveResults);
      auto intrinsic = LLVM::vector_deinterleave2::create(
          rewriter, loc, packedOpResults, adaptor.getSource());

      // Unpack the even (index 0) and odd (index 1) halves from the packed
      // intrinsic result.
      auto evenResult = LLVM::ExtractValueOp::create(
          rewriter, loc, intrinsic->getResult(0), 0);
      auto oddResult = LLVM::ExtractValueOp::create(rewriter, loc,
                                                    intrinsic->getResult(0), 1);

      rewriter.replaceOp(deinterleaveOp, ValueRange{evenResult, oddResult});
      return success();
    }
    // Lower fixed-size deinterleave to two shufflevectors. While the
    // vector.deinterleave2 intrinsic supports fixed and scalable vectors, the
    // langref still recommends fixed-vectors use shufflevector, see:
    // https://llvm.org/docs/LangRef.html#id889.
    int64_t resultVectorSize = resultType.getNumElements();
    SmallVector<int32_t> evenShuffleMask;
    SmallVector<int32_t> oddShuffleMask;

    evenShuffleMask.reserve(resultVectorSize);
    oddShuffleMask.reserve(resultVectorSize);

    // Even source lanes feed the first result, odd lanes the second.
    for (int i = 0; i < sourceType.getNumElements(); ++i) {
      if (i % 2 == 0)
        evenShuffleMask.push_back(i);
      else
        oddShuffleMask.push_back(i);
    }

    // The second shuffle operand is never read (the masks only index the
    // first operand), so poison suffices.
    auto poison = LLVM::PoisonOp::create(rewriter, loc, sourceType);
    auto evenShuffle = LLVM::ShuffleVectorOp::create(
        rewriter, loc, adaptor.getSource(), poison, evenShuffleMask);
    auto oddShuffle = LLVM::ShuffleVectorOp::create(
        rewriter, loc, adaptor.getSource(), poison, oddShuffleMask);

    rewriter.replaceOp(deinterleaveOp, ValueRange{evenShuffle, oddShuffle});
    return success();
  }
};
1899 
1900 /// Conversion pattern for a `vector.from_elements`.
1901 struct VectorFromElementsLowering
1902  : public ConvertOpToLLVMPattern<vector::FromElementsOp> {
1904 
1905  LogicalResult
1906  matchAndRewrite(vector::FromElementsOp fromElementsOp, OpAdaptor adaptor,
1907  ConversionPatternRewriter &rewriter) const override {
1908  Location loc = fromElementsOp.getLoc();
1909  VectorType vectorType = fromElementsOp.getType();
1910  // Only support 1-D vectors. Multi-dimensional vectors should have been
1911  // transformed to 1-D vectors by the vector-to-vector transformations before
1912  // this.
1913  if (vectorType.getRank() > 1)
1914  return rewriter.notifyMatchFailure(fromElementsOp,
1915  "rank > 1 vectors are not supported");
1916  Type llvmType = typeConverter->convertType(vectorType);
1917  Type llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
1918  Value result = LLVM::PoisonOp::create(rewriter, loc, llvmType);
1919  for (auto [idx, val] : llvm::enumerate(adaptor.getElements())) {
1920  auto constIdx =
1921  LLVM::ConstantOp::create(rewriter, loc, llvmIndexType, idx);
1922  result = LLVM::InsertElementOp::create(rewriter, loc, llvmType, result,
1923  val, constIdx);
1924  }
1925  rewriter.replaceOp(fromElementsOp, result);
1926  return success();
1927  }
1928 };
1929 
/// Conversion pattern for a `vector.to_elements`.
struct VectorToElementsLowering
    : public ConvertOpToLLVMPattern<vector::ToElementsOp> {

  LogicalResult
  matchAndRewrite(vector::ToElementsOp toElementsOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = toElementsOp.getLoc();
    auto idxType = typeConverter->convertType(rewriter.getIndexType());
    Value source = adaptor.getSource();

    // One replacement value per result; entries for dead results stay null.
    SmallVector<Value> results(toElementsOp->getNumResults());
    for (auto [idx, element] : llvm::enumerate(toElementsOp.getElements())) {
      // Create an extractelement operation only for results that are not dead.
      if (element.use_empty())
        continue;

      auto constIdx = LLVM::ConstantOp::create(
          rewriter, loc, idxType, rewriter.getIntegerAttr(idxType, idx));
      auto llvmType = typeConverter->convertType(element.getType());

      Value result = LLVM::ExtractElementOp::create(rewriter, loc, llvmType,
                                                    source, constIdx);
      results[idx] = result;
    }

    rewriter.replaceOp(toElementsOp, results);
    return success();
  }
};
1961 
1962 /// Conversion pattern for vector.step.
1963 struct VectorScalableStepOpLowering
1964  : public ConvertOpToLLVMPattern<vector::StepOp> {
1966 
1967  LogicalResult
1968  matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor,
1969  ConversionPatternRewriter &rewriter) const override {
1970  auto resultType = cast<VectorType>(stepOp.getType());
1971  if (!resultType.isScalable()) {
1972  return failure();
1973  }
1974  Type llvmType = typeConverter->convertType(stepOp.getType());
1975  rewriter.replaceOpWithNewOp<LLVM::StepVectorOp>(stepOp, llvmType);
1976  return success();
1977  }
1978 };
1979 
/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to:
/// ```
/// %flattened_a = vector.shape_cast %a
/// %flattened_b = vector.shape_cast %b
/// %flattened_d = vector.matrix_multiply %flattened_a, %flattened_b
/// %d = vector.shape_cast %flattened_d
/// %e = add %c, %d
/// ```
/// `vector.matrix_multiply` later lowers to `llvm.matrix.multiply`.
class ContractionOpToMatmulOpLowering
    : public vector::MaskableOpRewritePattern<vector::ContractionOp> {
public:
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  // High default benefit so this lowering wins over more generic contraction
  // lowerings when both would match.
  ContractionOpToMatmulOpLowering(MLIRContext *context,
                                  PatternBenefit benefit = 100)
      : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit) {}

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override;
};
2003 
/// Lower a qualifying `vector.contract %a, %b, %c` (with row-major matmul
/// semantics) directly into `llvm.intr.matrix.multiply`:
/// BEFORE:
/// ```mlir
/// %res = vector.contract #matmat_trait %lhs, %rhs, %acc
///   : vector<2x4xf32>, vector<4x3xf32> into vector<2x3xf32>
/// ```
///
/// AFTER:
/// ```mlir
/// %lhs = vector.shape_cast %arg0 : vector<2x4xf32> to vector<8xf32>
/// %rhs = vector.shape_cast %arg1 : vector<4x3xf32> to vector<12xf32>
/// %matmul = llvm.intr.matrix.multiply %lhs, %rhs
/// %res = arith.addf %acc, %matmul : vector<2x3xf32>
/// ```
///
/// Scalable vectors are not supported.
FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
    vector::ContractionOp op, MaskingOpInterface maskOp,
    PatternRewriter &rew) const {
  // TODO: Support vector.mask.
  if (maskOp)
    return failure();

  // Require plain matmul iterator structure: (parallel, parallel, reduction).
  auto iteratorTypes = op.getIteratorTypes().getValue();
  if (!isParallelIterator(iteratorTypes[0]) ||
      !isParallelIterator(iteratorTypes[1]) ||
      !isReductionIterator(iteratorTypes[2]))
    return failure();

  Type opResType = op.getType();
  VectorType vecType = dyn_cast<VectorType>(opResType);
  if (vecType && vecType.isScalable()) {
    // Note - this is sufficient to reject all cases with scalable vectors.
    return failure();
  }

  Type elementType = op.getLhsType().getElementType();
  if (!elementType.isIntOrFloat())
    return failure();

  // Mixed-precision contractions (lhs element type != result element type)
  // are not handled.
  Type dstElementType = vecType ? vecType.getElementType() : opResType;
  if (elementType != dstElementType)
    return failure();

  // Perform lhs + rhs transpositions to conform to matmul row-major semantics.
  // Bail out if the contraction cannot be put in this form.
  MLIRContext *ctx = op.getContext();
  Location loc = op.getLoc();
  AffineExpr m, n, k;
  bindDims(rew.getContext(), m, n, k);
  // LHS must be A(m, k) or A(k, m).
  Value lhs = op.getLhs();
  auto lhsMap = op.getIndexingMapsArray()[0];
  if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx))
    lhs = vector::TransposeOp::create(rew, loc, lhs, ArrayRef<int64_t>{1, 0});
  else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx))
    return failure();

  // RHS must be B(k, n) or B(n, k).
  Value rhs = op.getRhs();
  auto rhsMap = op.getIndexingMapsArray()[1];
  if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx))
    rhs = vector::TransposeOp::create(rew, loc, rhs, ArrayRef<int64_t>{1, 0});
  else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx))
    return failure();

  // At this point lhs and rhs are in row-major.
  VectorType lhsType = cast<VectorType>(lhs.getType());
  VectorType rhsType = cast<VectorType>(rhs.getType());
  int64_t lhsRows = lhsType.getDimSize(0);
  int64_t lhsColumns = lhsType.getDimSize(1);
  int64_t rhsColumns = rhsType.getDimSize(1);

  // llvm.intr.matrix.multiply takes flattened 1-D vectors, so shape_cast both
  // operands down to 1-D first.
  Type flattenedLHSType =
      VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
  lhs = vector::ShapeCastOp::create(rew, loc, flattenedLHSType, lhs);

  Type flattenedRHSType =
      VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
  rhs = vector::ShapeCastOp::create(rew, loc, flattenedRHSType, rhs);

  Value mul = LLVM::MatrixMultiplyOp::create(
      rew, loc,
      VectorType::get(lhsRows * rhsColumns,
                      cast<VectorType>(lhs.getType()).getElementType()),
      lhs, rhs, lhsRows, lhsColumns, rhsColumns);

  // Restore the 2-D shape expected by the accumulator addition below.
  mul = vector::ShapeCastOp::create(
      rew, loc,
      VectorType::get({lhsRows, rhsColumns},
                      getElementTypeOrSelf(op.getAcc().getType())),
      mul);

  // ACC must be C(m, n) or C(n, m).
  auto accMap = op.getIndexingMapsArray()[2];
  if (accMap == AffineMap::get(3, 0, {n, m}, ctx))
    mul = vector::TransposeOp::create(rew, loc, mul, ArrayRef<int64_t>{1, 0});
  else if (accMap != AffineMap::get(3, 0, {m, n}, ctx))
    llvm_unreachable("invalid contraction semantics");

  // Accumulate with the add op matching the element type.
  Value res = isa<IntegerType>(elementType)
                  ? static_cast<Value>(
                        arith::AddIOp::create(rew, loc, op.getAcc(), mul))
                  : static_cast<Value>(
                        arith::AddFOp::create(rew, loc, op.getAcc(), mul));

  return res;
}
2113 
2114 /// Lowers vector.transpose directly to llvm.intr.matrix.transpose
2115 ///
2116 /// BEFORE:
2117 /// ```mlir
2118 /// %tr = vector.transpose %vec, [1, 0] : vector<2x4xf32> to vector<4x2xf32>
2119 /// ```
2120 /// AFTER:
2121 /// ```mlir
2122 /// %vec_cs = vector.shape_cast %vec : vector<2x4xf32> to vector<8xf32>
2123 /// %tr = llvm.intr.matrix.transpose %vec_sc
2124 /// {columns = 2 : i32, rows = 4 : i32} : vector<8xf32> into vector<8xf32>
2125 /// %res = vector.shape_cast %tr : vector<8xf32> to vector<4x2xf32>
2126 /// ```
2127 class TransposeOpToMatrixTransposeOpLowering
2128  : public OpRewritePattern<vector::TransposeOp> {
2129 public:
2130  using Base::Base;
2131 
2132  LogicalResult matchAndRewrite(vector::TransposeOp op,
2133  PatternRewriter &rewriter) const override {
2134  auto loc = op.getLoc();
2135 
2136  Value input = op.getVector();
2137  VectorType inputType = op.getSourceVectorType();
2138  VectorType resType = op.getResultVectorType();
2139 
2140  if (inputType.isScalable())
2141  return rewriter.notifyMatchFailure(
2142  op, "This lowering does not support scalable vectors");
2143 
2144  // Set up convenience transposition table.
2145  ArrayRef<int64_t> transp = op.getPermutation();
2146 
2147  if (resType.getRank() != 2 || transp[0] != 1 || transp[1] != 0) {
2148  return failure();
2149  }
2150 
2151  Type flattenedType =
2152  VectorType::get(resType.getNumElements(), resType.getElementType());
2153  auto matrix =
2154  vector::ShapeCastOp::create(rewriter, loc, flattenedType, input);
2155  auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
2156  auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
2157  Value trans = LLVM::MatrixTransposeOp::create(rewriter, loc, flattenedType,
2158  matrix, rows, columns);
2159  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
2160  return success();
2161  }
2162 };
2163 
2164 } // namespace
2165 
2168  patterns.add<VectorFMAOpNDRewritePattern>(patterns.getContext());
2169 }
2170 
2173  patterns.add<ContractionOpToMatmulOpLowering>(patterns.getContext(), benefit);
2174 }
2175 
2178  patterns.add<TransposeOpToMatrixTransposeOpLowering>(patterns.getContext(),
2179  benefit);
2180 }
2181 
2182 /// Populate the given list with patterns that convert from Vector to LLVM.
2184  const LLVMTypeConverter &converter, RewritePatternSet &patterns,
2185  bool reassociateFPReductions, bool force32BitVectorIndices,
2186  bool useVectorAlignment) {
2187  // This function populates only ConversionPatterns, not RewritePatterns.
2188  MLIRContext *ctx = converter.getDialect()->getContext();
2189  patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
2190  patterns.add<VectorCreateMaskOpConversion>(ctx, force32BitVectorIndices);
2191  patterns.add<VectorLoadStoreConversion<vector::LoadOp>,
2192  VectorLoadStoreConversion<vector::MaskedLoadOp>,
2193  VectorLoadStoreConversion<vector::StoreOp>,
2194  VectorLoadStoreConversion<vector::MaskedStoreOp>,
2195  VectorGatherOpConversion, VectorScatterOpConversion>(
2196  converter, useVectorAlignment);
2197  patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
2198  VectorExtractOpConversion, VectorFMAOp1DConversion,
2199  VectorInsertOpConversion, VectorPrintOpConversion,
2200  VectorTypeCastOpConversion, VectorScaleOpConversion,
2201  VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
2202  VectorBroadcastScalarToLowRankLowering,
2203  VectorBroadcastScalarToNdLowering,
2204  VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
2205  MaskedReductionOpConversion, VectorInterleaveOpLowering,
2206  VectorDeinterleaveOpLowering, VectorFromElementsLowering,
2207  VectorToElementsLowering, VectorScalableStepOpLowering>(
2208  converter);
2209 }
2210 
2211 namespace {
2212 struct VectorToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
2214  void loadDependentDialects(MLIRContext *context) const final {
2215  context->loadDialect<LLVM::LLVMDialect>();
2216  }
2217 
2218  /// Hook for derived dialect interface to provide conversion patterns
2219  /// and mark dialect legal for the conversion target.
2220  void populateConvertToLLVMConversionPatterns(
2221  ConversionTarget &target, LLVMTypeConverter &typeConverter,
2222  RewritePatternSet &patterns) const final {
2224  }
2225 };
2226 } // namespace
2227 
2229  DialectRegistry &registry) {
2230  registry.addExtension(+[](MLIRContext *ctx, vector::VectorDialect *dialect) {
2231  dialect->addInterfaces<VectorToLLVMDialectInterface>();
2232  });
2233 }
static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc, const LLVMTypeConverter &typeConverter, MemRefType memRefType, Value llvmMemref, Value base, Value index, VectorType vectorType)
LogicalResult getVectorToLLVMAlignment(const LLVMTypeConverter &typeConverter, VectorType vectorType, MemRefType memrefType, unsigned &align, bool useVectorAlignment)
LogicalResult getVectorAlignment(const LLVMTypeConverter &typeConverter, VectorType vectorType, unsigned &align)
LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter, MemRefType memrefType, unsigned &align)
static Value extractOne(ConversionPatternRewriter &rewriter, const LLVMTypeConverter &typeConverter, Location loc, Value val, Type llvmType, int64_t rank, int64_t pos)
static Value insertOne(ConversionPatternRewriter &rewriter, const LLVMTypeConverter &typeConverter, Location loc, Value val1, Value val2, Type llvmType, int64_t rank, int64_t pos)
static Value getAsLLVMValue(OpBuilder &builder, Location loc, OpFoldResult foldResult)
Convert foldResult into a Value.
static LogicalResult isMemRefTypeSupported(MemRefType memRefType, const LLVMTypeConverter &converter)
union mlir::linalg::@1247::ArityGroupAndKind::Kind kind
@ None
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
#define MINUI(lhs, rhs)
static void printOp(llvm::raw_ostream &os, Operation *op, OpPrintingFlags &flags)
Definition: Unit.cpp:18
int64_t rows
Base type for affine expression.
Definition: AffineExpr.h:68
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:200
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:228
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:254
IntegerType getI64Type()
Definition: Builders.cpp:65
IntegerType getI32Type()
Definition: Builders.cpp:63
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:67
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:324
MLIRContext * getContext() const
Definition: Builders.h:56
IndexType getIndexType()
Definition: Builders.cpp:51
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:209
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition: Pattern.h:215
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
LLVM::LLVMDialect * getDialect() const
Returns the LLVM dialect.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type type or failure if...
const llvm::DataLayout & getDataLayout() const
Returns the data layout to use during and after conversion.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Definition: MemRefBuilder.h:33
LLVM::LLVMPointerType getElementPtrType()
Returns the (LLVM) pointer type this descriptor contains.
static MemRefDescriptor poison(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating a poison value of the descriptor type.
Generic implementation of one-to-one conversion from "SourceOp" to "TargetOp" where the latter belong...
Definition: Pattern.h:300
This class helps build Operations.
Definition: Builders.h:207
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
This class represents a single result from folding an operation.
Definition: OpDefinition.h:272
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:793
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:726
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:529
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
Definition: Types.cpp:112
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:116
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:24
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder, function_ref< void(ArrayRef< int64_t >)> fun)
NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType, const LLVMTypeConverter &converter)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintBF16Fn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintOpenFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
Value getStridedElementPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none)
Performs the index computation to get to the element at indices of the memory pointed to by memRefDes...
Definition: Pattern.cpp:478
Type getVectorType(Type elementType, unsigned numElements, bool isScalable=false)
Creates an LLVM dialect-compatible vector type with the given element type and length.
Definition: LLVMTypes.cpp:867
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintCommaFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintI64Fn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
Helper functions to look up or create the declaration for commonly used external C function calls.
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintNewlineFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintCloseFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintU64Fn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintF32Fn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
LogicalResult createPrintStrCall(OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName, StringRef string, const LLVMTypeConverter &typeConverter, bool addNewline=true, std::optional< StringRef > runtimeFunctionName={}, SymbolTableCollection *symbolTables=nullptr)
Generate IR that prints the given string to stdout.
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintF16Fn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
LLVM::FastmathFlags convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF)
Maps arithmetic fastmath enum values to LLVM enum values.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
bool isReductionIterator(Attribute attr)
Returns true if attr has "reduction" iterator type semantics.
Definition: VectorOps.h:155
void populateVectorContractToMatrixMultiply(RewritePatternSet &patterns, PatternBenefit benefit=100)
Populate the pattern set with the following patterns:
void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns)
Populates a pattern that rank-reduces n-D FMAs into (n-1)-D FMAs where n > 1.
bool isParallelIterator(Attribute attr)
Returns true if attr has "parallel" iterator type semantics.
Definition: VectorOps.h:150
void registerConvertVectorToLLVMInterface(DialectRegistry &registry)
SmallVector< int64_t > getAsIntegers(ArrayRef< Value > values)
Returns the integer numbers in values.
Definition: VectorOps.cpp:346
void populateVectorTransposeToFlatTranspose(RewritePatternSet &patterns, PatternBenefit benefit=100)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition: AffineExpr.h:311
Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, Type targetType, Value value)
Create a cast from an index-like value (index or integer) to another index-like value.
Definition: Utils.cpp:119
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
void populateVectorToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, bool reassociateFPReductions=false, bool force32BitVectorIndices=false, bool useVectorAlignment=false)
Collect a set of patterns to convert from the Vector dialect to LLVM.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
A pattern for ops that implement MaskableOpInterface and that might be masked (i.e.
Definition: VectorUtils.h:163