MLIR  20.0.0git
ConvertVectorToLLVM.cpp
Go to the documentation of this file.
1 //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
26 #include "mlir/IR/BuiltinTypes.h"
27 #include "mlir/IR/TypeUtilities.h"
30 #include "llvm/ADT/APFloat.h"
31 #include "llvm/Support/Casting.h"
32 #include <optional>
33 
34 using namespace mlir;
35 using namespace mlir::vector;
36 
37 // Helper to reduce vector type by *all* but one rank at back.
38 static VectorType reducedVectorTypeBack(VectorType tp) {
39  assert((tp.getRank() > 1) && "unlowerable vector type");
40  return VectorType::get(tp.getShape().take_back(), tp.getElementType(),
41  tp.getScalableDims().take_back());
42 }
43 
44 // Helper that picks the proper sequence for inserting.
46  const LLVMTypeConverter &typeConverter, Location loc,
47  Value val1, Value val2, Type llvmType, int64_t rank,
48  int64_t pos) {
49  assert(rank > 0 && "0-D vector corner case should have been handled already");
50  if (rank == 1) {
51  auto idxType = rewriter.getIndexType();
52  auto constant = rewriter.create<LLVM::ConstantOp>(
53  loc, typeConverter.convertType(idxType),
54  rewriter.getIntegerAttr(idxType, pos));
55  return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
56  constant);
57  }
58  return rewriter.create<LLVM::InsertValueOp>(loc, val1, val2, pos);
59 }
60 
61 // Helper that picks the proper sequence for extracting.
63  const LLVMTypeConverter &typeConverter, Location loc,
64  Value val, Type llvmType, int64_t rank, int64_t pos) {
65  if (rank <= 1) {
66  auto idxType = rewriter.getIndexType();
67  auto constant = rewriter.create<LLVM::ConstantOp>(
68  loc, typeConverter.convertType(idxType),
69  rewriter.getIntegerAttr(idxType, pos));
70  return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
71  constant);
72  }
73  return rewriter.create<LLVM::ExtractValueOp>(loc, val, pos);
74 }
75 
76 // Helper that returns data layout alignment of a memref.
77 LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter,
78  MemRefType memrefType, unsigned &align) {
79  Type elementTy = typeConverter.convertType(memrefType.getElementType());
80  if (!elementTy)
81  return failure();
82 
83  // TODO: this should use the MLIR data layout when it becomes available and
84  // stop depending on translation.
85  llvm::LLVMContext llvmContext;
86  align = LLVM::TypeToLLVMIRTranslator(llvmContext)
87  .getPreferredAlignment(elementTy, typeConverter.getDataLayout());
88  return success();
89 }
90 
91 // Check if the last stride is non-unit and has a valid memory space.
92 static LogicalResult isMemRefTypeSupported(MemRefType memRefType,
93  const LLVMTypeConverter &converter) {
94  if (!isLastMemrefDimUnitStride(memRefType))
95  return failure();
96  if (failed(converter.getMemRefAddressSpace(memRefType)))
97  return failure();
98  return success();
99 }
100 
101 // Add an index vector component to a base pointer.
103  const LLVMTypeConverter &typeConverter,
104  MemRefType memRefType, Value llvmMemref, Value base,
105  Value index, VectorType vectorType) {
106  assert(succeeded(isMemRefTypeSupported(memRefType, typeConverter)) &&
107  "unsupported memref type");
108  assert(vectorType.getRank() == 1 && "expected a 1-d vector type");
109  auto pType = MemRefDescriptor(llvmMemref).getElementPtrType();
110  auto ptrsType =
111  LLVM::getVectorType(pType, vectorType.getDimSize(0),
112  /*isScalable=*/vectorType.getScalableDims()[0]);
113  return rewriter.create<LLVM::GEPOp>(
114  loc, ptrsType, typeConverter.convertType(memRefType.getElementType()),
115  base, index);
116 }
117 
118 /// Convert `foldResult` into a Value. Integer attribute is converted to
119 /// an LLVM constant op.
120 static Value getAsLLVMValue(OpBuilder &builder, Location loc,
121  OpFoldResult foldResult) {
122  if (auto attr = foldResult.dyn_cast<Attribute>()) {
123  auto intAttr = cast<IntegerAttr>(attr);
124  return builder.create<LLVM::ConstantOp>(loc, intAttr).getResult();
125  }
126 
127  return foldResult.get<Value>();
128 }
129 
130 namespace {
131 
132 /// Trivial Vector to LLVM conversions
133 using VectorScaleOpConversion =
135 
136 /// Conversion pattern for a vector.bitcast.
137 class VectorBitCastOpConversion
138  : public ConvertOpToLLVMPattern<vector::BitCastOp> {
139 public:
141 
142  LogicalResult
143  matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor,
144  ConversionPatternRewriter &rewriter) const override {
145  // Only 0-D and 1-D vectors can be lowered to LLVM.
146  VectorType resultTy = bitCastOp.getResultVectorType();
147  if (resultTy.getRank() > 1)
148  return failure();
149  Type newResultTy = typeConverter->convertType(resultTy);
150  rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy,
151  adaptor.getOperands()[0]);
152  return success();
153  }
154 };
155 
156 /// Conversion pattern for a vector.matrix_multiply.
157 /// This is lowered directly to the proper llvm.intr.matrix.multiply.
158 class VectorMatmulOpConversion
159  : public ConvertOpToLLVMPattern<vector::MatmulOp> {
160 public:
162 
163  LogicalResult
164  matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor,
165  ConversionPatternRewriter &rewriter) const override {
166  rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
167  matmulOp, typeConverter->convertType(matmulOp.getRes().getType()),
168  adaptor.getLhs(), adaptor.getRhs(), matmulOp.getLhsRows(),
169  matmulOp.getLhsColumns(), matmulOp.getRhsColumns());
170  return success();
171  }
172 };
173 
174 /// Conversion pattern for a vector.flat_transpose.
175 /// This is lowered directly to the proper llvm.intr.matrix.transpose.
176 class VectorFlatTransposeOpConversion
177  : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> {
178 public:
180 
181  LogicalResult
182  matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor,
183  ConversionPatternRewriter &rewriter) const override {
184  rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
185  transOp, typeConverter->convertType(transOp.getRes().getType()),
186  adaptor.getMatrix(), transOp.getRows(), transOp.getColumns());
187  return success();
188  }
189 };
190 
191 /// Overloaded utility that replaces a vector.load, vector.store,
192 /// vector.maskedload and vector.maskedstore with their respective LLVM
193 /// couterparts.
194 static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
195  vector::LoadOpAdaptor adaptor,
196  VectorType vectorTy, Value ptr, unsigned align,
197  ConversionPatternRewriter &rewriter) {
198  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, vectorTy, ptr, align,
199  /*volatile_=*/false,
200  loadOp.getNontemporal());
201 }
202 
203 static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
204  vector::MaskedLoadOpAdaptor adaptor,
205  VectorType vectorTy, Value ptr, unsigned align,
206  ConversionPatternRewriter &rewriter) {
207  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
208  loadOp, vectorTy, ptr, adaptor.getMask(), adaptor.getPassThru(), align);
209 }
210 
211 static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
212  vector::StoreOpAdaptor adaptor,
213  VectorType vectorTy, Value ptr, unsigned align,
214  ConversionPatternRewriter &rewriter) {
215  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValueToStore(),
216  ptr, align, /*volatile_=*/false,
217  storeOp.getNontemporal());
218 }
219 
220 static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
221  vector::MaskedStoreOpAdaptor adaptor,
222  VectorType vectorTy, Value ptr, unsigned align,
223  ConversionPatternRewriter &rewriter) {
224  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
225  storeOp, adaptor.getValueToStore(), ptr, adaptor.getMask(), align);
226 }
227 
228 /// Conversion pattern for a vector.load, vector.store, vector.maskedload, and
229 /// vector.maskedstore.
230 template <class LoadOrStoreOp>
231 class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
232 public:
234 
235  LogicalResult
236  matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
237  typename LoadOrStoreOp::Adaptor adaptor,
238  ConversionPatternRewriter &rewriter) const override {
239  // Only 1-D vectors can be lowered to LLVM.
240  VectorType vectorTy = loadOrStoreOp.getVectorType();
241  if (vectorTy.getRank() > 1)
242  return failure();
243 
244  auto loc = loadOrStoreOp->getLoc();
245  MemRefType memRefTy = loadOrStoreOp.getMemRefType();
246 
247  // Resolve alignment.
248  unsigned align;
249  if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
250  return failure();
251 
252  // Resolve address.
253  auto vtype = cast<VectorType>(
254  this->typeConverter->convertType(loadOrStoreOp.getVectorType()));
255  Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(),
256  adaptor.getIndices(), rewriter);
257  replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, dataPtr, align,
258  rewriter);
259  return success();
260  }
261 };
262 
263 /// Conversion pattern for a vector.gather.
264 class VectorGatherOpConversion
265  : public ConvertOpToLLVMPattern<vector::GatherOp> {
266 public:
268 
269  LogicalResult
270  matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
271  ConversionPatternRewriter &rewriter) const override {
272  MemRefType memRefType = dyn_cast<MemRefType>(gather.getBaseType());
273  assert(memRefType && "The base should be bufferized");
274 
275  if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
276  return failure();
277 
278  auto loc = gather->getLoc();
279 
280  // Resolve alignment.
281  unsigned align;
282  if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
283  return failure();
284 
285  Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
286  adaptor.getIndices(), rewriter);
287  Value base = adaptor.getBase();
288 
289  auto llvmNDVectorTy = adaptor.getIndexVec().getType();
290  // Handle the simple case of 1-D vector.
291  if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy)) {
292  auto vType = gather.getVectorType();
293  // Resolve address.
294  Value ptrs =
295  getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
296  base, ptr, adaptor.getIndexVec(), vType);
297  // Replace with the gather intrinsic.
298  rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
299  gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
300  adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
301  return success();
302  }
303 
304  const LLVMTypeConverter &typeConverter = *this->getTypeConverter();
305  auto callback = [align, memRefType, base, ptr, loc, &rewriter,
306  &typeConverter](Type llvm1DVectorTy,
307  ValueRange vectorOperands) {
308  // Resolve address.
309  Value ptrs = getIndexedPtrs(
310  rewriter, loc, typeConverter, memRefType, base, ptr,
311  /*index=*/vectorOperands[0], cast<VectorType>(llvm1DVectorTy));
312  // Create the gather intrinsic.
313  return rewriter.create<LLVM::masked_gather>(
314  loc, llvm1DVectorTy, ptrs, /*mask=*/vectorOperands[1],
315  /*passThru=*/vectorOperands[2], rewriter.getI32IntegerAttr(align));
316  };
317  SmallVector<Value> vectorOperands = {
318  adaptor.getIndexVec(), adaptor.getMask(), adaptor.getPassThru()};
320  gather, vectorOperands, *getTypeConverter(), callback, rewriter);
321  }
322 };
323 
324 /// Conversion pattern for a vector.scatter.
325 class VectorScatterOpConversion
326  : public ConvertOpToLLVMPattern<vector::ScatterOp> {
327 public:
329 
330  LogicalResult
331  matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
332  ConversionPatternRewriter &rewriter) const override {
333  auto loc = scatter->getLoc();
334  MemRefType memRefType = scatter.getMemRefType();
335 
336  if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
337  return failure();
338 
339  // Resolve alignment.
340  unsigned align;
341  if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
342  return failure();
343 
344  // Resolve address.
345  VectorType vType = scatter.getVectorType();
346  Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
347  adaptor.getIndices(), rewriter);
348  Value ptrs =
349  getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
350  adaptor.getBase(), ptr, adaptor.getIndexVec(), vType);
351 
352  // Replace with the scatter intrinsic.
353  rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
354  scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(),
355  rewriter.getI32IntegerAttr(align));
356  return success();
357  }
358 };
359 
360 /// Conversion pattern for a vector.expandload.
361 class VectorExpandLoadOpConversion
362  : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
363 public:
365 
366  LogicalResult
367  matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor,
368  ConversionPatternRewriter &rewriter) const override {
369  auto loc = expand->getLoc();
370  MemRefType memRefType = expand.getMemRefType();
371 
372  // Resolve address.
373  auto vtype = typeConverter->convertType(expand.getVectorType());
374  Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
375  adaptor.getIndices(), rewriter);
376 
377  rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
378  expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru());
379  return success();
380  }
381 };
382 
383 /// Conversion pattern for a vector.compressstore.
384 class VectorCompressStoreOpConversion
385  : public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
386 public:
388 
389  LogicalResult
390  matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor,
391  ConversionPatternRewriter &rewriter) const override {
392  auto loc = compress->getLoc();
393  MemRefType memRefType = compress.getMemRefType();
394 
395  // Resolve address.
396  Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
397  adaptor.getIndices(), rewriter);
398 
399  rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
400  compress, adaptor.getValueToStore(), ptr, adaptor.getMask());
401  return success();
402  }
403 };
404 
/// Empty tag types used purely for overload selection: each one picks the
/// `createReductionNeutralValue` overload that builds the matching constant.
struct ReductionNeutralZero {};    // zero of the given type
struct ReductionNeutralIntOne {};  // integer one
struct ReductionNeutralFPOne {};   // floating-point one
struct ReductionNeutralAllOnes {}; // integer with every bit set
struct ReductionNeutralSIntMin {}; // signed integer minimum
struct ReductionNeutralUIntMin {}; // unsigned integer minimum
struct ReductionNeutralSIntMax {}; // signed integer maximum
struct ReductionNeutralUIntMax {}; // unsigned integer maximum
struct ReductionNeutralFPMin {};   // floating-point min start value
struct ReductionNeutralFPMax {};   // floating-point max start value
416 
417 /// Create the reduction neutral zero value.
418 static Value createReductionNeutralValue(ReductionNeutralZero neutral,
419  ConversionPatternRewriter &rewriter,
420  Location loc, Type llvmType) {
421  return rewriter.create<LLVM::ConstantOp>(loc, llvmType,
422  rewriter.getZeroAttr(llvmType));
423 }
424 
425 /// Create the reduction neutral integer one value.
426 static Value createReductionNeutralValue(ReductionNeutralIntOne neutral,
427  ConversionPatternRewriter &rewriter,
428  Location loc, Type llvmType) {
429  return rewriter.create<LLVM::ConstantOp>(
430  loc, llvmType, rewriter.getIntegerAttr(llvmType, 1));
431 }
432 
433 /// Create the reduction neutral fp one value.
434 static Value createReductionNeutralValue(ReductionNeutralFPOne neutral,
435  ConversionPatternRewriter &rewriter,
436  Location loc, Type llvmType) {
437  return rewriter.create<LLVM::ConstantOp>(
438  loc, llvmType, rewriter.getFloatAttr(llvmType, 1.0));
439 }
440 
441 /// Create the reduction neutral all-ones value.
442 static Value createReductionNeutralValue(ReductionNeutralAllOnes neutral,
443  ConversionPatternRewriter &rewriter,
444  Location loc, Type llvmType) {
445  return rewriter.create<LLVM::ConstantOp>(
446  loc, llvmType,
447  rewriter.getIntegerAttr(
448  llvmType, llvm::APInt::getAllOnes(llvmType.getIntOrFloatBitWidth())));
449 }
450 
451 /// Create the reduction neutral signed int minimum value.
452 static Value createReductionNeutralValue(ReductionNeutralSIntMin neutral,
453  ConversionPatternRewriter &rewriter,
454  Location loc, Type llvmType) {
455  return rewriter.create<LLVM::ConstantOp>(
456  loc, llvmType,
457  rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMinValue(
458  llvmType.getIntOrFloatBitWidth())));
459 }
460 
461 /// Create the reduction neutral unsigned int minimum value.
462 static Value createReductionNeutralValue(ReductionNeutralUIntMin neutral,
463  ConversionPatternRewriter &rewriter,
464  Location loc, Type llvmType) {
465  return rewriter.create<LLVM::ConstantOp>(
466  loc, llvmType,
467  rewriter.getIntegerAttr(llvmType, llvm::APInt::getMinValue(
468  llvmType.getIntOrFloatBitWidth())));
469 }
470 
471 /// Create the reduction neutral signed int maximum value.
472 static Value createReductionNeutralValue(ReductionNeutralSIntMax neutral,
473  ConversionPatternRewriter &rewriter,
474  Location loc, Type llvmType) {
475  return rewriter.create<LLVM::ConstantOp>(
476  loc, llvmType,
477  rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMaxValue(
478  llvmType.getIntOrFloatBitWidth())));
479 }
480 
481 /// Create the reduction neutral unsigned int maximum value.
482 static Value createReductionNeutralValue(ReductionNeutralUIntMax neutral,
483  ConversionPatternRewriter &rewriter,
484  Location loc, Type llvmType) {
485  return rewriter.create<LLVM::ConstantOp>(
486  loc, llvmType,
487  rewriter.getIntegerAttr(llvmType, llvm::APInt::getMaxValue(
488  llvmType.getIntOrFloatBitWidth())));
489 }
490 
491 /// Create the reduction neutral fp minimum value.
492 static Value createReductionNeutralValue(ReductionNeutralFPMin neutral,
493  ConversionPatternRewriter &rewriter,
494  Location loc, Type llvmType) {
495  auto floatType = cast<FloatType>(llvmType);
496  return rewriter.create<LLVM::ConstantOp>(
497  loc, llvmType,
498  rewriter.getFloatAttr(
499  llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
500  /*Negative=*/false)));
501 }
502 
503 /// Create the reduction neutral fp maximum value.
504 static Value createReductionNeutralValue(ReductionNeutralFPMax neutral,
505  ConversionPatternRewriter &rewriter,
506  Location loc, Type llvmType) {
507  auto floatType = cast<FloatType>(llvmType);
508  return rewriter.create<LLVM::ConstantOp>(
509  loc, llvmType,
510  rewriter.getFloatAttr(
511  llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
512  /*Negative=*/true)));
513 }
514 
515 /// Returns `accumulator` if it has a valid value. Otherwise, creates and
516 /// returns a new accumulator value using `ReductionNeutral`.
517 template <class ReductionNeutral>
518 static Value getOrCreateAccumulator(ConversionPatternRewriter &rewriter,
519  Location loc, Type llvmType,
520  Value accumulator) {
521  if (accumulator)
522  return accumulator;
523 
524  return createReductionNeutralValue(ReductionNeutral(), rewriter, loc,
525  llvmType);
526 }
527 
528 /// Creates a value with the 1-D vector shape provided in `llvmType`.
529 /// This is used as effective vector length by some intrinsics supporting
530 /// dynamic vector lengths at runtime.
531 static Value createVectorLengthValue(ConversionPatternRewriter &rewriter,
532  Location loc, Type llvmType) {
533  VectorType vType = cast<VectorType>(llvmType);
534  auto vShape = vType.getShape();
535  assert(vShape.size() == 1 && "Unexpected multi-dim vector type");
536 
537  Value baseVecLength = rewriter.create<LLVM::ConstantOp>(
538  loc, rewriter.getI32Type(),
539  rewriter.getIntegerAttr(rewriter.getI32Type(), vShape[0]));
540 
541  if (!vType.getScalableDims()[0])
542  return baseVecLength;
543 
544  // For a scalable vector type, create and return `vScale * baseVecLength`.
545  Value vScale = rewriter.create<vector::VectorScaleOp>(loc);
546  vScale =
547  rewriter.create<arith::IndexCastOp>(loc, rewriter.getI32Type(), vScale);
548  Value scalableVecLength =
549  rewriter.create<arith::MulIOp>(loc, baseVecLength, vScale);
550  return scalableVecLength;
551 }
552 
553 /// Helper method to lower a `vector.reduction` op that performs an arithmetic
554 /// operation like add,mul, etc.. `VectorOp` is the LLVM vector intrinsic to use
555 /// and `ScalarOp` is the scalar operation used to add the accumulation value if
556 /// non-null.
557 template <class LLVMRedIntrinOp, class ScalarOp>
558 static Value createIntegerReductionArithmeticOpLowering(
559  ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
560  Value vectorOperand, Value accumulator) {
561 
562  Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
563 
564  if (accumulator)
565  result = rewriter.create<ScalarOp>(loc, accumulator, result);
566  return result;
567 }
568 
569 /// Helper method to lower a `vector.reduction` operation that performs
570 /// a comparison operation like `min`/`max`. `VectorOp` is the LLVM vector
571 /// intrinsic to use and `predicate` is the predicate to use to compare+combine
572 /// the accumulator value if non-null.
573 template <class LLVMRedIntrinOp>
574 static Value createIntegerReductionComparisonOpLowering(
575  ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
576  Value vectorOperand, Value accumulator, LLVM::ICmpPredicate predicate) {
577  Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
578  if (accumulator) {
579  Value cmp =
580  rewriter.create<LLVM::ICmpOp>(loc, predicate, accumulator, result);
581  result = rewriter.create<LLVM::SelectOp>(loc, cmp, accumulator, result);
582  }
583  return result;
584 }
585 
586 namespace {
587 template <typename Source>
588 struct VectorToScalarMapper;
589 template <>
590 struct VectorToScalarMapper<LLVM::vector_reduce_fmaximum> {
591  using Type = LLVM::MaximumOp;
592 };
593 template <>
594 struct VectorToScalarMapper<LLVM::vector_reduce_fminimum> {
595  using Type = LLVM::MinimumOp;
596 };
597 template <>
598 struct VectorToScalarMapper<LLVM::vector_reduce_fmax> {
599  using Type = LLVM::MaxNumOp;
600 };
601 template <>
602 struct VectorToScalarMapper<LLVM::vector_reduce_fmin> {
603  using Type = LLVM::MinNumOp;
604 };
605 } // namespace
606 
607 template <class LLVMRedIntrinOp>
608 static Value createFPReductionComparisonOpLowering(
609  ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
610  Value vectorOperand, Value accumulator, LLVM::FastmathFlagsAttr fmf) {
611  Value result =
612  rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand, fmf);
613 
614  if (accumulator) {
615  result =
616  rewriter.create<typename VectorToScalarMapper<LLVMRedIntrinOp>::Type>(
617  loc, result, accumulator);
618  }
619 
620  return result;
621 }
622 
/// Empty tag types used purely for overload selection: each one picks the
/// `getMaskNeutralValue` overload for the corresponding masked reduction.
struct MaskNeutralFMaximum {};
struct MaskNeutralFMinimum {};
626 
627 /// Get the mask neutral floating point maximum value
628 static llvm::APFloat
629 getMaskNeutralValue(MaskNeutralFMaximum,
630  const llvm::fltSemantics &floatSemantics) {
631  return llvm::APFloat::getSmallest(floatSemantics, /*Negative=*/true);
632 }
633 /// Get the mask neutral floating point minimum value
634 static llvm::APFloat
635 getMaskNeutralValue(MaskNeutralFMinimum,
636  const llvm::fltSemantics &floatSemantics) {
637  return llvm::APFloat::getLargest(floatSemantics, /*Negative=*/false);
638 }
639 
640 /// Create the mask neutral floating point MLIR vector constant
641 template <typename MaskNeutral>
642 static Value createMaskNeutralValue(ConversionPatternRewriter &rewriter,
643  Location loc, Type llvmType,
644  Type vectorType) {
645  const auto &floatSemantics = cast<FloatType>(llvmType).getFloatSemantics();
646  auto value = getMaskNeutralValue(MaskNeutral{}, floatSemantics);
647  auto denseValue = DenseElementsAttr::get(cast<ShapedType>(vectorType), value);
648  return rewriter.create<LLVM::ConstantOp>(loc, vectorType, denseValue);
649 }
650 
651 /// Lowers masked `fmaximum` and `fminimum` reductions using the non-masked
652 /// intrinsics. It is a workaround to overcome the lack of masked intrinsics for
653 /// `fmaximum`/`fminimum`.
654 /// More information: https://github.com/llvm/llvm-project/issues/64940
655 template <class LLVMRedIntrinOp, class MaskNeutral>
656 static Value
657 lowerMaskedReductionWithRegular(ConversionPatternRewriter &rewriter,
658  Location loc, Type llvmType,
659  Value vectorOperand, Value accumulator,
660  Value mask, LLVM::FastmathFlagsAttr fmf) {
661  const Value vectorMaskNeutral = createMaskNeutralValue<MaskNeutral>(
662  rewriter, loc, llvmType, vectorOperand.getType());
663  const Value selectedVectorByMask = rewriter.create<LLVM::SelectOp>(
664  loc, mask, vectorOperand, vectorMaskNeutral);
665  return createFPReductionComparisonOpLowering<LLVMRedIntrinOp>(
666  rewriter, loc, llvmType, selectedVectorByMask, accumulator, fmf);
667 }
668 
669 template <class LLVMRedIntrinOp, class ReductionNeutral>
670 static Value
671 lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, Location loc,
672  Type llvmType, Value vectorOperand,
673  Value accumulator, LLVM::FastmathFlagsAttr fmf) {
674  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
675  llvmType, accumulator);
676  return rewriter.create<LLVMRedIntrinOp>(loc, llvmType,
677  /*startValue=*/accumulator,
678  vectorOperand, fmf);
679 }
680 
681 /// Overloaded methods to lower a *predicated* reduction to an llvm intrinsic
682 /// that requires a start value. This start value format spans across fp
683 /// reductions without mask and all the masked reduction intrinsics.
684 template <class LLVMVPRedIntrinOp, class ReductionNeutral>
685 static Value
686 lowerPredicatedReductionWithStartValue(ConversionPatternRewriter &rewriter,
687  Location loc, Type llvmType,
688  Value vectorOperand, Value accumulator) {
689  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
690  llvmType, accumulator);
691  return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
692  /*startValue=*/accumulator,
693  vectorOperand);
694 }
695 
696 template <class LLVMVPRedIntrinOp, class ReductionNeutral>
697 static Value lowerPredicatedReductionWithStartValue(
698  ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
699  Value vectorOperand, Value accumulator, Value mask) {
700  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
701  llvmType, accumulator);
702  Value vectorLength =
703  createVectorLengthValue(rewriter, loc, vectorOperand.getType());
704  return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
705  /*startValue=*/accumulator,
706  vectorOperand, mask, vectorLength);
707 }
708 
709 template <class LLVMIntVPRedIntrinOp, class IntReductionNeutral,
710  class LLVMFPVPRedIntrinOp, class FPReductionNeutral>
711 static Value lowerPredicatedReductionWithStartValue(
712  ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
713  Value vectorOperand, Value accumulator, Value mask) {
714  if (llvmType.isIntOrIndex())
715  return lowerPredicatedReductionWithStartValue<LLVMIntVPRedIntrinOp,
716  IntReductionNeutral>(
717  rewriter, loc, llvmType, vectorOperand, accumulator, mask);
718 
719  // FP dispatch.
720  return lowerPredicatedReductionWithStartValue<LLVMFPVPRedIntrinOp,
721  FPReductionNeutral>(
722  rewriter, loc, llvmType, vectorOperand, accumulator, mask);
723 }
724 
725 /// Conversion pattern for all vector reductions.
726 class VectorReductionOpConversion
727  : public ConvertOpToLLVMPattern<vector::ReductionOp> {
728 public:
729  explicit VectorReductionOpConversion(const LLVMTypeConverter &typeConv,
730  bool reassociateFPRed)
731  : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
732  reassociateFPReductions(reassociateFPRed) {}
733 
734  LogicalResult
735  matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
736  ConversionPatternRewriter &rewriter) const override {
737  auto kind = reductionOp.getKind();
738  Type eltType = reductionOp.getDest().getType();
739  Type llvmType = typeConverter->convertType(eltType);
740  Value operand = adaptor.getVector();
741  Value acc = adaptor.getAcc();
742  Location loc = reductionOp.getLoc();
743 
744  if (eltType.isIntOrIndex()) {
745  // Integer reductions: add/mul/min/max/and/or/xor.
746  Value result;
747  switch (kind) {
748  case vector::CombiningKind::ADD:
749  result =
750  createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_add,
751  LLVM::AddOp>(
752  rewriter, loc, llvmType, operand, acc);
753  break;
754  case vector::CombiningKind::MUL:
755  result =
756  createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_mul,
757  LLVM::MulOp>(
758  rewriter, loc, llvmType, operand, acc);
759  break;
761  result = createIntegerReductionComparisonOpLowering<
762  LLVM::vector_reduce_umin>(rewriter, loc, llvmType, operand, acc,
763  LLVM::ICmpPredicate::ule);
764  break;
765  case vector::CombiningKind::MINSI:
766  result = createIntegerReductionComparisonOpLowering<
767  LLVM::vector_reduce_smin>(rewriter, loc, llvmType, operand, acc,
768  LLVM::ICmpPredicate::sle);
769  break;
770  case vector::CombiningKind::MAXUI:
771  result = createIntegerReductionComparisonOpLowering<
772  LLVM::vector_reduce_umax>(rewriter, loc, llvmType, operand, acc,
773  LLVM::ICmpPredicate::uge);
774  break;
775  case vector::CombiningKind::MAXSI:
776  result = createIntegerReductionComparisonOpLowering<
777  LLVM::vector_reduce_smax>(rewriter, loc, llvmType, operand, acc,
778  LLVM::ICmpPredicate::sge);
779  break;
780  case vector::CombiningKind::AND:
781  result =
782  createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_and,
783  LLVM::AndOp>(
784  rewriter, loc, llvmType, operand, acc);
785  break;
786  case vector::CombiningKind::OR:
787  result =
788  createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_or,
789  LLVM::OrOp>(
790  rewriter, loc, llvmType, operand, acc);
791  break;
792  case vector::CombiningKind::XOR:
793  result =
794  createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_xor,
795  LLVM::XOrOp>(
796  rewriter, loc, llvmType, operand, acc);
797  break;
798  default:
799  return failure();
800  }
801  rewriter.replaceOp(reductionOp, result);
802 
803  return success();
804  }
805 
806  if (!isa<FloatType>(eltType))
807  return failure();
808 
809  arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
810  LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
811  reductionOp.getContext(),
812  convertArithFastMathFlagsToLLVM(fMFAttr.getValue()));
814  reductionOp.getContext(),
815  fmf.getValue() | (reassociateFPReductions ? LLVM::FastmathFlags::reassoc
816  : LLVM::FastmathFlags::none));
817 
818  // Floating-point reductions: add/mul/min/max
819  Value result;
820  if (kind == vector::CombiningKind::ADD) {
821  result = lowerReductionWithStartValue<LLVM::vector_reduce_fadd,
822  ReductionNeutralZero>(
823  rewriter, loc, llvmType, operand, acc, fmf);
824  } else if (kind == vector::CombiningKind::MUL) {
825  result = lowerReductionWithStartValue<LLVM::vector_reduce_fmul,
826  ReductionNeutralFPOne>(
827  rewriter, loc, llvmType, operand, acc, fmf);
828  } else if (kind == vector::CombiningKind::MINIMUMF) {
829  result =
830  createFPReductionComparisonOpLowering<LLVM::vector_reduce_fminimum>(
831  rewriter, loc, llvmType, operand, acc, fmf);
832  } else if (kind == vector::CombiningKind::MAXIMUMF) {
833  result =
834  createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmaximum>(
835  rewriter, loc, llvmType, operand, acc, fmf);
836  } else if (kind == vector::CombiningKind::MINNUMF) {
837  result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmin>(
838  rewriter, loc, llvmType, operand, acc, fmf);
839  } else if (kind == vector::CombiningKind::MAXNUMF) {
840  result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmax>(
841  rewriter, loc, llvmType, operand, acc, fmf);
842  } else
843  return failure();
844 
845  rewriter.replaceOp(reductionOp, result);
846  return success();
847  }
848 
849 private:
850  const bool reassociateFPReductions;
851 };
852 
853 /// Base class to convert a `vector.mask` operation while matching traits
854 /// of the maskable operation nested inside. A `VectorMaskOpConversionBase`
855 /// instance matches against a `vector.mask` operation. The `matchAndRewrite`
856 /// method performs a second match against the maskable operation `MaskedOp`.
857 /// Finally, it invokes the virtual method `matchAndRewriteMaskableOp` to be
858 /// implemented by the concrete conversion classes. This method can match
859 /// against specific traits of the `vector.mask` and the maskable operation. It
860 /// must replace the `vector.mask` operation.
861 template <class MaskedOp>
862 class VectorMaskOpConversionBase
863  : public ConvertOpToLLVMPattern<vector::MaskOp> {
864 public:
866 
867  LogicalResult
868  matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
869  ConversionPatternRewriter &rewriter) const final {
870  // Match against the maskable operation kind.
871  auto maskedOp = llvm::dyn_cast_or_null<MaskedOp>(maskOp.getMaskableOp());
872  if (!maskedOp)
873  return failure();
874  return matchAndRewriteMaskableOp(maskOp, maskedOp, rewriter);
875  }
876 
877 protected:
878  virtual LogicalResult
879  matchAndRewriteMaskableOp(vector::MaskOp maskOp,
880  vector::MaskableOpInterface maskableOp,
881  ConversionPatternRewriter &rewriter) const = 0;
882 };
883 
884 class MaskedReductionOpConversion
885  : public VectorMaskOpConversionBase<vector::ReductionOp> {
886 
887 public:
888  using VectorMaskOpConversionBase<
889  vector::ReductionOp>::VectorMaskOpConversionBase;
890 
891  LogicalResult matchAndRewriteMaskableOp(
892  vector::MaskOp maskOp, MaskableOpInterface maskableOp,
893  ConversionPatternRewriter &rewriter) const override {
894  auto reductionOp = cast<ReductionOp>(maskableOp.getOperation());
895  auto kind = reductionOp.getKind();
896  Type eltType = reductionOp.getDest().getType();
897  Type llvmType = typeConverter->convertType(eltType);
898  Value operand = reductionOp.getVector();
899  Value acc = reductionOp.getAcc();
900  Location loc = reductionOp.getLoc();
901 
902  arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
903  LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
904  reductionOp.getContext(),
905  convertArithFastMathFlagsToLLVM(fMFAttr.getValue()));
906 
907  Value result;
908  switch (kind) {
909  case vector::CombiningKind::ADD:
910  result = lowerPredicatedReductionWithStartValue<
911  LLVM::VPReduceAddOp, ReductionNeutralZero, LLVM::VPReduceFAddOp,
912  ReductionNeutralZero>(rewriter, loc, llvmType, operand, acc,
913  maskOp.getMask());
914  break;
915  case vector::CombiningKind::MUL:
916  result = lowerPredicatedReductionWithStartValue<
917  LLVM::VPReduceMulOp, ReductionNeutralIntOne, LLVM::VPReduceFMulOp,
918  ReductionNeutralFPOne>(rewriter, loc, llvmType, operand, acc,
919  maskOp.getMask());
920  break;
922  result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMinOp,
923  ReductionNeutralUIntMax>(
924  rewriter, loc, llvmType, operand, acc, maskOp.getMask());
925  break;
926  case vector::CombiningKind::MINSI:
927  result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMinOp,
928  ReductionNeutralSIntMax>(
929  rewriter, loc, llvmType, operand, acc, maskOp.getMask());
930  break;
931  case vector::CombiningKind::MAXUI:
932  result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMaxOp,
933  ReductionNeutralUIntMin>(
934  rewriter, loc, llvmType, operand, acc, maskOp.getMask());
935  break;
936  case vector::CombiningKind::MAXSI:
937  result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMaxOp,
938  ReductionNeutralSIntMin>(
939  rewriter, loc, llvmType, operand, acc, maskOp.getMask());
940  break;
941  case vector::CombiningKind::AND:
942  result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceAndOp,
943  ReductionNeutralAllOnes>(
944  rewriter, loc, llvmType, operand, acc, maskOp.getMask());
945  break;
946  case vector::CombiningKind::OR:
947  result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceOrOp,
948  ReductionNeutralZero>(
949  rewriter, loc, llvmType, operand, acc, maskOp.getMask());
950  break;
951  case vector::CombiningKind::XOR:
952  result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceXorOp,
953  ReductionNeutralZero>(
954  rewriter, loc, llvmType, operand, acc, maskOp.getMask());
955  break;
956  case vector::CombiningKind::MINNUMF:
957  result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMinOp,
958  ReductionNeutralFPMax>(
959  rewriter, loc, llvmType, operand, acc, maskOp.getMask());
960  break;
961  case vector::CombiningKind::MAXNUMF:
962  result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMaxOp,
963  ReductionNeutralFPMin>(
964  rewriter, loc, llvmType, operand, acc, maskOp.getMask());
965  break;
966  case CombiningKind::MAXIMUMF:
967  result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fmaximum,
968  MaskNeutralFMaximum>(
969  rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
970  break;
971  case CombiningKind::MINIMUMF:
972  result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fminimum,
973  MaskNeutralFMinimum>(
974  rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
975  break;
976  }
977 
978  // Replace `vector.mask` operation altogether.
979  rewriter.replaceOp(maskOp, result);
980  return success();
981  }
982 };
983 
984 class VectorShuffleOpConversion
985  : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
986 public:
988 
989  LogicalResult
990  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
991  ConversionPatternRewriter &rewriter) const override {
992  auto loc = shuffleOp->getLoc();
993  auto v1Type = shuffleOp.getV1VectorType();
994  auto v2Type = shuffleOp.getV2VectorType();
995  auto vectorType = shuffleOp.getResultVectorType();
996  Type llvmType = typeConverter->convertType(vectorType);
997  ArrayRef<int64_t> mask = shuffleOp.getMask();
998 
999  // Bail if result type cannot be lowered.
1000  if (!llvmType)
1001  return failure();
1002 
1003  // Get rank and dimension sizes.
1004  int64_t rank = vectorType.getRank();
1005 #ifndef NDEBUG
1006  bool wellFormed0DCase =
1007  v1Type.getRank() == 0 && v2Type.getRank() == 0 && rank == 1;
1008  bool wellFormedNDCase =
1009  v1Type.getRank() == rank && v2Type.getRank() == rank;
1010  assert((wellFormed0DCase || wellFormedNDCase) && "op is not well-formed");
1011 #endif
1012 
1013  // For rank 0 and 1, where both operands have *exactly* the same vector
1014  // type, there is direct shuffle support in LLVM. Use it!
1015  if (rank <= 1 && v1Type == v2Type) {
1016  Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
1017  loc, adaptor.getV1(), adaptor.getV2(),
1018  llvm::to_vector_of<int32_t>(mask));
1019  rewriter.replaceOp(shuffleOp, llvmShuffleOp);
1020  return success();
1021  }
1022 
1023  // For all other cases, insert the individual values individually.
1024  int64_t v1Dim = v1Type.getDimSize(0);
1025  Type eltType;
1026  if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(llvmType))
1027  eltType = arrayType.getElementType();
1028  else
1029  eltType = cast<VectorType>(llvmType).getElementType();
1030  Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
1031  int64_t insPos = 0;
1032  for (int64_t extPos : mask) {
1033  Value value = adaptor.getV1();
1034  if (extPos >= v1Dim) {
1035  extPos -= v1Dim;
1036  value = adaptor.getV2();
1037  }
1038  Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
1039  eltType, rank, extPos);
1040  insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
1041  llvmType, rank, insPos++);
1042  }
1043  rewriter.replaceOp(shuffleOp, insert);
1044  return success();
1045  }
1046 };
1047 
1048 class VectorExtractElementOpConversion
1049  : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
1050 public:
1051  using ConvertOpToLLVMPattern<
1052  vector::ExtractElementOp>::ConvertOpToLLVMPattern;
1053 
1054  LogicalResult
1055  matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
1056  ConversionPatternRewriter &rewriter) const override {
1057  auto vectorType = extractEltOp.getSourceVectorType();
1058  auto llvmType = typeConverter->convertType(vectorType.getElementType());
1059 
1060  // Bail if result type cannot be lowered.
1061  if (!llvmType)
1062  return failure();
1063 
1064  if (vectorType.getRank() == 0) {
1065  Location loc = extractEltOp.getLoc();
1066  auto idxType = rewriter.getIndexType();
1067  auto zero = rewriter.create<LLVM::ConstantOp>(
1068  loc, typeConverter->convertType(idxType),
1069  rewriter.getIntegerAttr(idxType, 0));
1070  rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
1071  extractEltOp, llvmType, adaptor.getVector(), zero);
1072  return success();
1073  }
1074 
1075  rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
1076  extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition());
1077  return success();
1078  }
1079 };
1080 
1081 class VectorExtractOpConversion
1082  : public ConvertOpToLLVMPattern<vector::ExtractOp> {
1083 public:
1085 
1086  LogicalResult
1087  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
1088  ConversionPatternRewriter &rewriter) const override {
1089  auto loc = extractOp->getLoc();
1090  auto resultType = extractOp.getResult().getType();
1091  auto llvmResultType = typeConverter->convertType(resultType);
1092  // Bail if result type cannot be lowered.
1093  if (!llvmResultType)
1094  return failure();
1095 
1097  adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
1098 
1099  // The Vector -> LLVM lowering models N-D vectors as nested aggregates of
1100  // 1-d vectors. This nesting is modeled using arrays. We do this conversion
1101  // from a N-d vector extract to a nested aggregate vector extract in two
1102  // steps:
1103  // - Extract a member from the nested aggregate. The result can be
1104  // a lower rank nested aggregate or a vector (1-D). This is done using
1105  // `llvm.extractvalue`.
1106  // - Extract a scalar out of the vector if needed. This is done using
1107  // `llvm.extractelement`.
1108 
1109  // Determine if we need to extract a member out of the aggregate. We
1110  // always need to extract a member if the input rank >= 2.
1111  bool extractsAggregate = extractOp.getSourceVectorType().getRank() >= 2;
1112  // Determine if we need to extract a scalar as the result. We extract
1113  // a scalar if the extract is full rank, i.e., the number of indices is
1114  // equal to source vector rank.
1115  bool extractsScalar = static_cast<int64_t>(positionVec.size()) ==
1116  extractOp.getSourceVectorType().getRank();
1117 
1118  // Since the LLVM type converter converts 0-d vectors to 1-d vectors, we
1119  // need to add a position for this change.
1120  if (extractOp.getSourceVectorType().getRank() == 0) {
1121  Type idxType = typeConverter->convertType(rewriter.getIndexType());
1122  positionVec.push_back(rewriter.getZeroAttr(idxType));
1123  }
1124 
1125  Value extracted = adaptor.getVector();
1126  if (extractsAggregate) {
1127  ArrayRef<OpFoldResult> position(positionVec);
1128  if (extractsScalar) {
1129  // If we are extracting a scalar from the extracted member, we drop
1130  // the last index, which will be used to extract the scalar out of the
1131  // vector.
1132  position = position.drop_back();
1133  }
1134  // llvm.extractvalue does not support dynamic dimensions.
1135  if (!llvm::all_of(position, llvm::IsaPred<Attribute>)) {
1136  return failure();
1137  }
1138  extracted = rewriter.create<LLVM::ExtractValueOp>(
1139  loc, extracted, getAsIntegers(position));
1140  }
1141 
1142  if (extractsScalar) {
1143  extracted = rewriter.create<LLVM::ExtractElementOp>(
1144  loc, extracted, getAsLLVMValue(rewriter, loc, positionVec.back()));
1145  }
1146 
1147  rewriter.replaceOp(extractOp, extracted);
1148  return success();
1149  }
1150 };
1151 
1152 /// Conversion pattern that turns a vector.fma on a 1-D vector
1153 /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
1154 /// This does not match vectors of n >= 2 rank.
1155 ///
1156 /// Example:
1157 /// ```
1158 /// vector.fma %a, %a, %a : vector<8xf32>
1159 /// ```
1160 /// is converted to:
1161 /// ```
1162 /// llvm.intr.fmuladd %va, %va, %va:
1163 /// (!llvm."<8 x f32>">, !llvm<"<8 x f32>">, !llvm<"<8 x f32>">)
1164 /// -> !llvm."<8 x f32>">
1165 /// ```
1166 class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
1167 public:
1169 
1170  LogicalResult
1171  matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
1172  ConversionPatternRewriter &rewriter) const override {
1173  VectorType vType = fmaOp.getVectorType();
1174  if (vType.getRank() > 1)
1175  return failure();
1176 
1177  rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(
1178  fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
1179  return success();
1180  }
1181 };
1182 
1183 class VectorInsertElementOpConversion
1184  : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
1185 public:
1187 
1188  LogicalResult
1189  matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor,
1190  ConversionPatternRewriter &rewriter) const override {
1191  auto vectorType = insertEltOp.getDestVectorType();
1192  auto llvmType = typeConverter->convertType(vectorType);
1193 
1194  // Bail if result type cannot be lowered.
1195  if (!llvmType)
1196  return failure();
1197 
1198  if (vectorType.getRank() == 0) {
1199  Location loc = insertEltOp.getLoc();
1200  auto idxType = rewriter.getIndexType();
1201  auto zero = rewriter.create<LLVM::ConstantOp>(
1202  loc, typeConverter->convertType(idxType),
1203  rewriter.getIntegerAttr(idxType, 0));
1204  rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
1205  insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero);
1206  return success();
1207  }
1208 
1209  rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
1210  insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(),
1211  adaptor.getPosition());
1212  return success();
1213  }
1214 };
1215 
1216 class VectorInsertOpConversion
1217  : public ConvertOpToLLVMPattern<vector::InsertOp> {
1218 public:
1220 
1221  LogicalResult
1222  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
1223  ConversionPatternRewriter &rewriter) const override {
1224  auto loc = insertOp->getLoc();
1225  auto sourceType = insertOp.getSourceType();
1226  auto destVectorType = insertOp.getDestVectorType();
1227  auto llvmResultType = typeConverter->convertType(destVectorType);
1228  // Bail if result type cannot be lowered.
1229  if (!llvmResultType)
1230  return failure();
1231 
1233  adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
1234 
1235  // Overwrite entire vector with value. Should be handled by folder, but
1236  // just to be safe.
1237  ArrayRef<OpFoldResult> position(positionVec);
1238  if (position.empty()) {
1239  rewriter.replaceOp(insertOp, adaptor.getSource());
1240  return success();
1241  }
1242 
1243  // One-shot insertion of a vector into an array (only requires insertvalue).
1244  if (isa<VectorType>(sourceType)) {
1245  if (insertOp.hasDynamicPosition())
1246  return failure();
1247 
1248  Value inserted = rewriter.create<LLVM::InsertValueOp>(
1249  loc, adaptor.getDest(), adaptor.getSource(), getAsIntegers(position));
1250  rewriter.replaceOp(insertOp, inserted);
1251  return success();
1252  }
1253 
1254  // Potential extraction of 1-D vector from array.
1255  Value extracted = adaptor.getDest();
1256  auto oneDVectorType = destVectorType;
1257  if (position.size() > 1) {
1258  if (insertOp.hasDynamicPosition())
1259  return failure();
1260 
1261  oneDVectorType = reducedVectorTypeBack(destVectorType);
1262  extracted = rewriter.create<LLVM::ExtractValueOp>(
1263  loc, extracted, getAsIntegers(position.drop_back()));
1264  }
1265 
1266  // Insertion of an element into a 1-D LLVM vector.
1267  Value inserted = rewriter.create<LLVM::InsertElementOp>(
1268  loc, typeConverter->convertType(oneDVectorType), extracted,
1269  adaptor.getSource(), getAsLLVMValue(rewriter, loc, position.back()));
1270 
1271  // Potential insertion of resulting 1-D vector into array.
1272  if (position.size() > 1) {
1273  if (insertOp.hasDynamicPosition())
1274  return failure();
1275 
1276  inserted = rewriter.create<LLVM::InsertValueOp>(
1277  loc, adaptor.getDest(), inserted,
1278  getAsIntegers(position.drop_back()));
1279  }
1280 
1281  rewriter.replaceOp(insertOp, inserted);
1282  return success();
1283  }
1284 };
1285 
1286 /// Lower vector.scalable.insert ops to LLVM vector.insert
1287 struct VectorScalableInsertOpLowering
1288  : public ConvertOpToLLVMPattern<vector::ScalableInsertOp> {
1289  using ConvertOpToLLVMPattern<
1290  vector::ScalableInsertOp>::ConvertOpToLLVMPattern;
1291 
1292  LogicalResult
1293  matchAndRewrite(vector::ScalableInsertOp insOp, OpAdaptor adaptor,
1294  ConversionPatternRewriter &rewriter) const override {
1295  rewriter.replaceOpWithNewOp<LLVM::vector_insert>(
1296  insOp, adaptor.getDest(), adaptor.getSource(), adaptor.getPos());
1297  return success();
1298  }
1299 };
1300 
1301 /// Lower vector.scalable.extract ops to LLVM vector.extract
1302 struct VectorScalableExtractOpLowering
1303  : public ConvertOpToLLVMPattern<vector::ScalableExtractOp> {
1304  using ConvertOpToLLVMPattern<
1305  vector::ScalableExtractOp>::ConvertOpToLLVMPattern;
1306 
1307  LogicalResult
1308  matchAndRewrite(vector::ScalableExtractOp extOp, OpAdaptor adaptor,
1309  ConversionPatternRewriter &rewriter) const override {
1310  rewriter.replaceOpWithNewOp<LLVM::vector_extract>(
1311  extOp, typeConverter->convertType(extOp.getResultVectorType()),
1312  adaptor.getSource(), adaptor.getPos());
1313  return success();
1314  }
1315 };
1316 
1317 /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
1318 ///
1319 /// Example:
1320 /// ```
1321 /// %d = vector.fma %a, %b, %c : vector<2x4xf32>
1322 /// ```
1323 /// is rewritten into:
1324 /// ```
1325 /// %r = splat %f0: vector<2x4xf32>
1326 /// %va = vector.extractvalue %a[0] : vector<2x4xf32>
1327 /// %vb = vector.extractvalue %b[0] : vector<2x4xf32>
1328 /// %vc = vector.extractvalue %c[0] : vector<2x4xf32>
1329 /// %vd = vector.fma %va, %vb, %vc : vector<4xf32>
1330 /// %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
1331 /// %va2 = vector.extractvalue %a2[1] : vector<2x4xf32>
1332 /// %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32>
1333 /// %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32>
1334 /// %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
1335 /// %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
1336 /// // %r3 holds the final value.
1337 /// ```
1338 class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
1339 public:
1341 
1342  void initialize() {
1343  // This pattern recursively unpacks one dimension at a time. The recursion
1344  // bounded as the rank is strictly decreasing.
1345  setHasBoundedRewriteRecursion();
1346  }
1347 
1348  LogicalResult matchAndRewrite(FMAOp op,
1349  PatternRewriter &rewriter) const override {
1350  auto vType = op.getVectorType();
1351  if (vType.getRank() < 2)
1352  return failure();
1353 
1354  auto loc = op.getLoc();
1355  auto elemType = vType.getElementType();
1356  Value zero = rewriter.create<arith::ConstantOp>(
1357  loc, elemType, rewriter.getZeroAttr(elemType));
1358  Value desc = rewriter.create<vector::SplatOp>(loc, vType, zero);
1359  for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
1360  Value extrLHS = rewriter.create<ExtractOp>(loc, op.getLhs(), i);
1361  Value extrRHS = rewriter.create<ExtractOp>(loc, op.getRhs(), i);
1362  Value extrACC = rewriter.create<ExtractOp>(loc, op.getAcc(), i);
1363  Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
1364  desc = rewriter.create<InsertOp>(loc, fma, desc, i);
1365  }
1366  rewriter.replaceOp(op, desc);
1367  return success();
1368  }
1369 };
1370 
1371 /// Returns the strides if the memory underlying `memRefType` has a contiguous
1372 /// static layout.
1373 static std::optional<SmallVector<int64_t, 4>>
1374 computeContiguousStrides(MemRefType memRefType) {
1375  int64_t offset;
1376  SmallVector<int64_t, 4> strides;
1377  if (failed(getStridesAndOffset(memRefType, strides, offset)))
1378  return std::nullopt;
1379  if (!strides.empty() && strides.back() != 1)
1380  return std::nullopt;
1381  // If no layout or identity layout, this is contiguous by definition.
1382  if (memRefType.getLayout().isIdentity())
1383  return strides;
1384 
1385  // Otherwise, we must determine contiguity form shapes. This can only ever
1386  // work in static cases because MemRefType is underspecified to represent
1387  // contiguous dynamic shapes in other ways than with just empty/identity
1388  // layout.
1389  auto sizes = memRefType.getShape();
1390  for (int index = 0, e = strides.size() - 1; index < e; ++index) {
1391  if (ShapedType::isDynamic(sizes[index + 1]) ||
1392  ShapedType::isDynamic(strides[index]) ||
1393  ShapedType::isDynamic(strides[index + 1]))
1394  return std::nullopt;
1395  if (strides[index] != strides[index + 1] * sizes[index + 1])
1396  return std::nullopt;
1397  }
1398  return strides;
1399 }
1400 
1401 class VectorTypeCastOpConversion
1402  : public ConvertOpToLLVMPattern<vector::TypeCastOp> {
1403 public:
1405 
1406  LogicalResult
1407  matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor,
1408  ConversionPatternRewriter &rewriter) const override {
1409  auto loc = castOp->getLoc();
1410  MemRefType sourceMemRefType =
1411  cast<MemRefType>(castOp.getOperand().getType());
1412  MemRefType targetMemRefType = castOp.getType();
1413 
1414  // Only static shape casts supported atm.
1415  if (!sourceMemRefType.hasStaticShape() ||
1416  !targetMemRefType.hasStaticShape())
1417  return failure();
1418 
1419  auto llvmSourceDescriptorTy =
1420  dyn_cast<LLVM::LLVMStructType>(adaptor.getOperands()[0].getType());
1421  if (!llvmSourceDescriptorTy)
1422  return failure();
1423  MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]);
1424 
1425  auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
1426  typeConverter->convertType(targetMemRefType));
1427  if (!llvmTargetDescriptorTy)
1428  return failure();
1429 
1430  // Only contiguous source buffers supported atm.
1431  auto sourceStrides = computeContiguousStrides(sourceMemRefType);
1432  if (!sourceStrides)
1433  return failure();
1434  auto targetStrides = computeContiguousStrides(targetMemRefType);
1435  if (!targetStrides)
1436  return failure();
1437  // Only support static strides for now, regardless of contiguity.
1438  if (llvm::any_of(*targetStrides, ShapedType::isDynamic))
1439  return failure();
1440 
1441  auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
1442 
1443  // Create descriptor.
1444  auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
1445  // Set allocated ptr.
1446  Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
1447  desc.setAllocatedPtr(rewriter, loc, allocated);
1448 
1449  // Set aligned ptr.
1450  Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
1451  desc.setAlignedPtr(rewriter, loc, ptr);
1452  // Fill offset 0.
1453  auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
1454  auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
1455  desc.setOffset(rewriter, loc, zero);
1456 
1457  // Fill size and stride descriptors in memref.
1458  for (const auto &indexedSize :
1459  llvm::enumerate(targetMemRefType.getShape())) {
1460  int64_t index = indexedSize.index();
1461  auto sizeAttr =
1462  rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
1463  auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
1464  desc.setSize(rewriter, loc, index, size);
1465  auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
1466  (*targetStrides)[index]);
1467  auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
1468  desc.setStride(rewriter, loc, index, stride);
1469  }
1470 
1471  rewriter.replaceOp(castOp, {desc});
1472  return success();
1473  }
1474 };
1475 
1476 /// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only).
1477 /// Non-scalable versions of this operation are handled in Vector Transforms.
1478 class VectorCreateMaskOpConversion
1479  : public OpConversionPattern<vector::CreateMaskOp> {
1480 public:
1481  explicit VectorCreateMaskOpConversion(MLIRContext *context,
1482  bool enableIndexOpt)
1483  : OpConversionPattern<vector::CreateMaskOp>(context),
1484  force32BitVectorIndices(enableIndexOpt) {}
1485 
1486  LogicalResult
1487  matchAndRewrite(vector::CreateMaskOp op, OpAdaptor adaptor,
1488  ConversionPatternRewriter &rewriter) const override {
1489  auto dstType = op.getType();
1490  if (dstType.getRank() != 1 || !cast<VectorType>(dstType).isScalable())
1491  return failure();
1492  IntegerType idxType =
1493  force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
1494  auto loc = op->getLoc();
1495  Value indices = rewriter.create<LLVM::StepVectorOp>(
1496  loc, LLVM::getVectorType(idxType, dstType.getShape()[0],
1497  /*isScalable=*/true));
1498  auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType,
1499  adaptor.getOperands()[0]);
1500  Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
1501  Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
1502  indices, bounds);
1503  rewriter.replaceOp(op, comp);
1504  return success();
1505  }
1506 
1507 private:
1508  const bool force32BitVectorIndices;
1509 };
1510 
1511 class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
1512 public:
1514 
1515  // Lowering implementation that relies on a small runtime support library,
1516  // which only needs to provide a few printing methods (single value for all
1517  // data types, opening/closing bracket, comma, newline). The lowering splits
1518  // the vector into elementary printing operations. The advantage of this
1519  // approach is that the library can remain unaware of all low-level
1520  // implementation details of vectors while still supporting output of any
1521  // shaped and dimensioned vector.
1522  //
1523  // Note: This lowering only handles scalars, n-D vectors are broken into
1524  // printing scalars in loops in VectorToSCF.
1525  //
1526  // TODO: rely solely on libc in future? something else?
1527  //
1528  LogicalResult
1529  matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor,
1530  ConversionPatternRewriter &rewriter) const override {
1531  auto parent = printOp->getParentOfType<ModuleOp>();
1532  if (!parent)
1533  return failure();
1534 
1535  auto loc = printOp->getLoc();
1536 
1537  if (auto value = adaptor.getSource()) {
1538  Type printType = printOp.getPrintType();
1539  if (isa<VectorType>(printType)) {
1540  // Vectors should be broken into elementary print ops in VectorToSCF.
1541  return failure();
1542  }
1543  if (failed(emitScalarPrint(rewriter, parent, loc, printType, value)))
1544  return failure();
1545  }
1546 
1547  auto punct = printOp.getPunctuation();
1548  if (auto stringLiteral = printOp.getStringLiteral()) {
1549  LLVM::createPrintStrCall(rewriter, loc, parent, "vector_print_str",
1550  *stringLiteral, *getTypeConverter(),
1551  /*addNewline=*/false);
1552  } else if (punct != PrintPunctuation::NoPunctuation) {
1553  emitCall(rewriter, printOp->getLoc(), [&] {
1554  switch (punct) {
1555  case PrintPunctuation::Close:
1556  return LLVM::lookupOrCreatePrintCloseFn(parent);
1557  case PrintPunctuation::Open:
1558  return LLVM::lookupOrCreatePrintOpenFn(parent);
1559  case PrintPunctuation::Comma:
1560  return LLVM::lookupOrCreatePrintCommaFn(parent);
1561  case PrintPunctuation::NewLine:
1562  return LLVM::lookupOrCreatePrintNewlineFn(parent);
1563  default:
1564  llvm_unreachable("unexpected punctuation");
1565  }
1566  }());
1567  }
1568 
1569  rewriter.eraseOp(printOp);
1570  return success();
1571  }
1572 
1573 private:
1574  enum class PrintConversion {
1575  // clang-format off
1576  None,
1577  ZeroExt64,
1578  SignExt64,
1579  Bitcast16
1580  // clang-format on
1581  };
1582 
1583  LogicalResult emitScalarPrint(ConversionPatternRewriter &rewriter,
1584  ModuleOp parent, Location loc, Type printType,
1585  Value value) const {
1586  if (typeConverter->convertType(printType) == nullptr)
1587  return failure();
1588 
1589  // Make sure element type has runtime support.
1590  PrintConversion conversion = PrintConversion::None;
1591  Operation *printer;
1592  if (printType.isF32()) {
1593  printer = LLVM::lookupOrCreatePrintF32Fn(parent);
1594  } else if (printType.isF64()) {
1595  printer = LLVM::lookupOrCreatePrintF64Fn(parent);
1596  } else if (printType.isF16()) {
1597  conversion = PrintConversion::Bitcast16; // bits!
1598  printer = LLVM::lookupOrCreatePrintF16Fn(parent);
1599  } else if (printType.isBF16()) {
1600  conversion = PrintConversion::Bitcast16; // bits!
1601  printer = LLVM::lookupOrCreatePrintBF16Fn(parent);
1602  } else if (printType.isIndex()) {
1603  printer = LLVM::lookupOrCreatePrintU64Fn(parent);
1604  } else if (auto intTy = dyn_cast<IntegerType>(printType)) {
1605  // Integers need a zero or sign extension on the operand
1606  // (depending on the source type) as well as a signed or
1607  // unsigned print method. Up to 64-bit is supported.
1608  unsigned width = intTy.getWidth();
1609  if (intTy.isUnsigned()) {
1610  if (width <= 64) {
1611  if (width < 64)
1612  conversion = PrintConversion::ZeroExt64;
1613  printer = LLVM::lookupOrCreatePrintU64Fn(parent);
1614  } else {
1615  return failure();
1616  }
1617  } else {
1618  assert(intTy.isSignless() || intTy.isSigned());
1619  if (width <= 64) {
1620  // Note that we *always* zero extend booleans (1-bit integers),
1621  // so that true/false is printed as 1/0 rather than -1/0.
1622  if (width == 1)
1623  conversion = PrintConversion::ZeroExt64;
1624  else if (width < 64)
1625  conversion = PrintConversion::SignExt64;
1626  printer = LLVM::lookupOrCreatePrintI64Fn(parent);
1627  } else {
1628  return failure();
1629  }
1630  }
1631  } else {
1632  return failure();
1633  }
1634 
1635  switch (conversion) {
1636  case PrintConversion::ZeroExt64:
1637  value = rewriter.create<arith::ExtUIOp>(
1638  loc, IntegerType::get(rewriter.getContext(), 64), value);
1639  break;
1640  case PrintConversion::SignExt64:
1641  value = rewriter.create<arith::ExtSIOp>(
1642  loc, IntegerType::get(rewriter.getContext(), 64), value);
1643  break;
1644  case PrintConversion::Bitcast16:
1645  value = rewriter.create<LLVM::BitcastOp>(
1646  loc, IntegerType::get(rewriter.getContext(), 16), value);
1647  break;
1648  case PrintConversion::None:
1649  break;
1650  }
1651  emitCall(rewriter, loc, printer, value);
1652  return success();
1653  }
1654 
1655  // Helper to emit a call.
1656  static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
1657  Operation *ref, ValueRange params = ValueRange()) {
1658  rewriter.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(ref),
1659  params);
1660  }
1661 };
1662 
1663 /// The Splat operation is lowered to an insertelement + a shufflevector
1664 /// operation. Splat to only 0-d and 1-d vector result types are lowered.
1665 struct VectorSplatOpLowering : public ConvertOpToLLVMPattern<vector::SplatOp> {
1667 
1668  LogicalResult
1669  matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
1670  ConversionPatternRewriter &rewriter) const override {
1671  VectorType resultType = cast<VectorType>(splatOp.getType());
1672  if (resultType.getRank() > 1)
1673  return failure();
1674 
1675  // First insert it into an undef vector so we can shuffle it.
1676  auto vectorType = typeConverter->convertType(splatOp.getType());
1677  Value undef = rewriter.create<LLVM::UndefOp>(splatOp.getLoc(), vectorType);
1678  auto zero = rewriter.create<LLVM::ConstantOp>(
1679  splatOp.getLoc(),
1680  typeConverter->convertType(rewriter.getIntegerType(32)),
1681  rewriter.getZeroAttr(rewriter.getIntegerType(32)));
1682 
1683  // For 0-d vector, we simply do `insertelement`.
1684  if (resultType.getRank() == 0) {
1685  rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
1686  splatOp, vectorType, undef, adaptor.getInput(), zero);
1687  return success();
1688  }
1689 
1690  // For 1-d vector, we additionally do a `vectorshuffle`.
1691  auto v = rewriter.create<LLVM::InsertElementOp>(
1692  splatOp.getLoc(), vectorType, undef, adaptor.getInput(), zero);
1693 
1694  int64_t width = cast<VectorType>(splatOp.getType()).getDimSize(0);
1695  SmallVector<int32_t> zeroValues(width, 0);
1696 
1697  // Shuffle the value across the desired number of elements.
1698  rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, undef,
1699  zeroValues);
1700  return success();
1701  }
1702 };
1703 
1704 /// The Splat operation is lowered to an insertelement + a shufflevector
1705 /// operation. Splat to only 2+-d vector result types are lowered by the
1706 /// SplatNdOpLowering, the 1-d case is handled by SplatOpLowering.
1707 struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
1709 
1710  LogicalResult
1711  matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
1712  ConversionPatternRewriter &rewriter) const override {
1713  VectorType resultType = splatOp.getType();
1714  if (resultType.getRank() <= 1)
1715  return failure();
1716 
1717  // First insert it into an undef vector so we can shuffle it.
1718  auto loc = splatOp.getLoc();
1719  auto vectorTypeInfo =
1720  LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter());
1721  auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
1722  auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
1723  if (!llvmNDVectorTy || !llvm1DVectorTy)
1724  return failure();
1725 
1726  // Construct returned value.
1727  Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmNDVectorTy);
1728 
1729  // Construct a 1-D vector with the splatted value that we insert in all the
1730  // places within the returned descriptor.
1731  Value vdesc = rewriter.create<LLVM::UndefOp>(loc, llvm1DVectorTy);
1732  auto zero = rewriter.create<LLVM::ConstantOp>(
1733  loc, typeConverter->convertType(rewriter.getIntegerType(32)),
1734  rewriter.getZeroAttr(rewriter.getIntegerType(32)));
1735  Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
1736  adaptor.getInput(), zero);
1737 
1738  // Shuffle the value across the desired number of elements.
1739  int64_t width = resultType.getDimSize(resultType.getRank() - 1);
1740  SmallVector<int32_t> zeroValues(width, 0);
1741  v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroValues);
1742 
1743  // Iterate of linear index, convert to coords space and insert splatted 1-D
1744  // vector in each position.
1745  nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
1746  desc = rewriter.create<LLVM::InsertValueOp>(loc, desc, v, position);
1747  });
1748  rewriter.replaceOp(splatOp, desc);
1749  return success();
1750  }
1751 };
1752 
1753 /// Conversion pattern for a `vector.interleave`.
1754 /// This supports fixed-sized vectors and scalable vectors.
1755 struct VectorInterleaveOpLowering
1756  : public ConvertOpToLLVMPattern<vector::InterleaveOp> {
1758 
1759  LogicalResult
1760  matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
1761  ConversionPatternRewriter &rewriter) const override {
1762  VectorType resultType = interleaveOp.getResultVectorType();
1763  // n-D interleaves should have been lowered already.
1764  if (resultType.getRank() != 1)
1765  return rewriter.notifyMatchFailure(interleaveOp,
1766  "InterleaveOp not rank 1");
1767  // If the result is rank 1, then this directly maps to LLVM.
1768  if (resultType.isScalable()) {
1769  rewriter.replaceOpWithNewOp<LLVM::vector_interleave2>(
1770  interleaveOp, typeConverter->convertType(resultType),
1771  adaptor.getLhs(), adaptor.getRhs());
1772  return success();
1773  }
1774  // Lower fixed-size interleaves to a shufflevector. While the
1775  // vector.interleave2 intrinsic supports fixed and scalable vectors, the
1776  // langref still recommends fixed-vectors use shufflevector, see:
1777  // https://llvm.org/docs/LangRef.html#id876.
1778  int64_t resultVectorSize = resultType.getNumElements();
1779  SmallVector<int32_t> interleaveShuffleMask;
1780  interleaveShuffleMask.reserve(resultVectorSize);
1781  for (int i = 0, end = resultVectorSize / 2; i < end; ++i) {
1782  interleaveShuffleMask.push_back(i);
1783  interleaveShuffleMask.push_back((resultVectorSize / 2) + i);
1784  }
1785  rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(
1786  interleaveOp, adaptor.getLhs(), adaptor.getRhs(),
1787  interleaveShuffleMask);
1788  return success();
1789  }
1790 };
1791 
1792 /// Conversion pattern for a `vector.deinterleave`.
1793 /// This supports fixed-sized vectors and scalable vectors.
1794 struct VectorDeinterleaveOpLowering
1795  : public ConvertOpToLLVMPattern<vector::DeinterleaveOp> {
1797 
1798  LogicalResult
1799  matchAndRewrite(vector::DeinterleaveOp deinterleaveOp, OpAdaptor adaptor,
1800  ConversionPatternRewriter &rewriter) const override {
1801  VectorType resultType = deinterleaveOp.getResultVectorType();
1802  VectorType sourceType = deinterleaveOp.getSourceVectorType();
1803  auto loc = deinterleaveOp.getLoc();
1804 
1805  // Note: n-D deinterleave operations should be lowered to the 1-D before
1806  // converting to LLVM.
1807  if (resultType.getRank() != 1)
1808  return rewriter.notifyMatchFailure(deinterleaveOp,
1809  "DeinterleaveOp not rank 1");
1810 
1811  if (resultType.isScalable()) {
1812  auto llvmTypeConverter = this->getTypeConverter();
1813  auto deinterleaveResults = deinterleaveOp.getResultTypes();
1814  auto packedOpResults =
1815  llvmTypeConverter->packOperationResults(deinterleaveResults);
1816  auto intrinsic = rewriter.create<LLVM::vector_deinterleave2>(
1817  loc, packedOpResults, adaptor.getSource());
1818 
1819  auto evenResult = rewriter.create<LLVM::ExtractValueOp>(
1820  loc, intrinsic->getResult(0), 0);
1821  auto oddResult = rewriter.create<LLVM::ExtractValueOp>(
1822  loc, intrinsic->getResult(0), 1);
1823 
1824  rewriter.replaceOp(deinterleaveOp, ValueRange{evenResult, oddResult});
1825  return success();
1826  }
1827  // Lower fixed-size deinterleave to two shufflevectors. While the
1828  // vector.deinterleave2 intrinsic supports fixed and scalable vectors, the
1829  // langref still recommends fixed-vectors use shufflevector, see:
1830  // https://llvm.org/docs/LangRef.html#id889.
1831  int64_t resultVectorSize = resultType.getNumElements();
1832  SmallVector<int32_t> evenShuffleMask;
1833  SmallVector<int32_t> oddShuffleMask;
1834 
1835  evenShuffleMask.reserve(resultVectorSize);
1836  oddShuffleMask.reserve(resultVectorSize);
1837 
1838  for (int i = 0; i < sourceType.getNumElements(); ++i) {
1839  if (i % 2 == 0)
1840  evenShuffleMask.push_back(i);
1841  else
1842  oddShuffleMask.push_back(i);
1843  }
1844 
1845  auto poison = rewriter.create<LLVM::PoisonOp>(loc, sourceType);
1846  auto evenShuffle = rewriter.create<LLVM::ShuffleVectorOp>(
1847  loc, adaptor.getSource(), poison, evenShuffleMask);
1848  auto oddShuffle = rewriter.create<LLVM::ShuffleVectorOp>(
1849  loc, adaptor.getSource(), poison, oddShuffleMask);
1850 
1851  rewriter.replaceOp(deinterleaveOp, ValueRange{evenShuffle, oddShuffle});
1852  return success();
1853  }
1854 };
1855 
1856 /// Conversion pattern for a `vector.from_elements`.
1857 struct VectorFromElementsLowering
1858  : public ConvertOpToLLVMPattern<vector::FromElementsOp> {
1860 
1861  LogicalResult
1862  matchAndRewrite(vector::FromElementsOp fromElementsOp, OpAdaptor adaptor,
1863  ConversionPatternRewriter &rewriter) const override {
1864  Location loc = fromElementsOp.getLoc();
1865  VectorType vectorType = fromElementsOp.getType();
1866  // TODO: Multi-dimensional vectors lower to !llvm.array<... x vector<>>.
1867  // Such ops should be handled in the same way as vector.insert.
1868  if (vectorType.getRank() > 1)
1869  return rewriter.notifyMatchFailure(fromElementsOp,
1870  "rank > 1 vectors are not supported");
1871  Type llvmType = typeConverter->convertType(vectorType);
1872  Value result = rewriter.create<LLVM::UndefOp>(loc, llvmType);
1873  for (auto [idx, val] : llvm::enumerate(adaptor.getElements()))
1874  result = rewriter.create<vector::InsertOp>(loc, val, result, idx);
1875  rewriter.replaceOp(fromElementsOp, result);
1876  return success();
1877  }
1878 };
1879 
1880 /// Conversion pattern for vector.step.
1881 struct VectorScalableStepOpLowering
1882  : public ConvertOpToLLVMPattern<vector::StepOp> {
1884 
1885  LogicalResult
1886  matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor,
1887  ConversionPatternRewriter &rewriter) const override {
1888  auto resultType = cast<VectorType>(stepOp.getType());
1889  if (!resultType.isScalable()) {
1890  return failure();
1891  }
1892  Type llvmType = typeConverter->convertType(stepOp.getType());
1893  rewriter.replaceOpWithNewOp<LLVM::StepVectorOp>(stepOp, llvmType);
1894  return success();
1895  }
1896 };
1897 
1898 } // namespace
1899 
1902  patterns.add<VectorFMAOpNDRewritePattern>(patterns.getContext());
1903 }
1904 
1905 /// Populate the given list with patterns that convert from Vector to LLVM.
1907  const LLVMTypeConverter &converter, RewritePatternSet &patterns,
1908  bool reassociateFPReductions, bool force32BitVectorIndices) {
1909  // This function populates only ConversionPatterns, not RewritePatterns.
1910  MLIRContext *ctx = converter.getDialect()->getContext();
1911  patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
1912  patterns.add<VectorCreateMaskOpConversion>(ctx, force32BitVectorIndices);
1913  patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
1914  VectorExtractElementOpConversion, VectorExtractOpConversion,
1915  VectorFMAOp1DConversion, VectorInsertElementOpConversion,
1916  VectorInsertOpConversion, VectorPrintOpConversion,
1917  VectorTypeCastOpConversion, VectorScaleOpConversion,
1918  VectorLoadStoreConversion<vector::LoadOp>,
1919  VectorLoadStoreConversion<vector::MaskedLoadOp>,
1920  VectorLoadStoreConversion<vector::StoreOp>,
1921  VectorLoadStoreConversion<vector::MaskedStoreOp>,
1922  VectorGatherOpConversion, VectorScatterOpConversion,
1923  VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
1924  VectorSplatOpLowering, VectorSplatNdOpLowering,
1925  VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
1926  MaskedReductionOpConversion, VectorInterleaveOpLowering,
1927  VectorDeinterleaveOpLowering, VectorFromElementsLowering,
1928  VectorScalableStepOpLowering>(converter);
1929 }
1930 
1932  const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
1933  patterns.add<VectorMatmulOpConversion>(converter);
1934  patterns.add<VectorFlatTransposeOpConversion>(converter);
1935 }
static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc, const LLVMTypeConverter &typeConverter, MemRefType memRefType, Value llvmMemref, Value base, Value index, VectorType vectorType)
LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter, MemRefType memrefType, unsigned &align)
static Value extractOne(ConversionPatternRewriter &rewriter, const LLVMTypeConverter &typeConverter, Location loc, Value val, Type llvmType, int64_t rank, int64_t pos)
static Value insertOne(ConversionPatternRewriter &rewriter, const LLVMTypeConverter &typeConverter, Location loc, Value val1, Value val2, Type llvmType, int64_t rank, int64_t pos)
static Value getAsLLVMValue(OpBuilder &builder, Location loc, OpFoldResult foldResult)
Convert foldResult into a Value.
static LogicalResult isMemRefTypeSupported(MemRefType memRefType, const LLVMTypeConverter &converter)
static VectorType reducedVectorTypeBack(VectorType tp)
@ None
#define MINUI(lhs, rhs)
static void printOp(llvm::raw_ostream &os, Operation *op, OpPrintingFlags &flags)
Definition: Unit.cpp:19
Attributes are known-constant values of operations.
Definition: Attributes.h:25
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:240
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:268
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:294
IntegerType getI64Type()
Definition: Builders.cpp:109
IntegerType getI32Type()
Definition: Builders.cpp:107
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:111
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:364
MLIRContext * getContext() const
Definition: Builders.h:56
IndexType getIndexType()
Definition: Builders.cpp:95
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:143
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition: Pattern.h:149
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
LLVM::LLVMDialect * getDialect() const
Returns the LLVM dialect.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type type or failure if...
const llvm::DataLayout & getDataLayout() const
Returns the data layout to use during and after conversion.
Utility class to translate MLIR LLVM dialect types to LLVM IR.
Definition: TypeToLLVM.h:39
unsigned getPreferredAlignment(Type type, const llvm::DataLayout &layout)
Returns the preferred alignment for the type given the data layout.
Definition: TypeToLLVM.cpp:196
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Definition: MemRefBuilder.h:33
static MemRefDescriptor undef(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating an undef value of the descriptor type.
LLVM::LLVMPointerType getElementPtrType()
Returns the (LLVM) pointer type this descriptor contains.
Generic implementation of one-to-one conversion from "SourceOp" to "TargetOp" where the latter belong...
Definition: Pattern.h:225
This class helps build Operations.
Definition: Builders.h:216
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
Definition: Types.cpp:123
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:133
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder, function_ref< void(ArrayRef< int64_t >)> fun)
NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType, const LLVMTypeConverter &converter)
Type getVectorType(Type elementType, unsigned numElements, bool isScalable=false)
Creates an LLVM dialect-compatible vector type with the given element type and length.
Definition: LLVMTypes.cpp:930
LLVM::LLVMFuncOp lookupOrCreatePrintI64Fn(Operation *moduleOp)
Helper functions to lookup or create the declaration for commonly used external C function calls.
LLVM::LLVMFuncOp lookupOrCreatePrintBF16Fn(Operation *moduleOp)
LLVM::LLVMFuncOp lookupOrCreatePrintF16Fn(Operation *moduleOp)
LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(Operation *moduleOp)
void createPrintStrCall(OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName, StringRef string, const LLVMTypeConverter &typeConverter, bool addNewline=true, std::optional< StringRef > runtimeFunctionName={})
Generate IR that prints the given string to stdout.
LLVM::LLVMFuncOp lookupOrCreatePrintF64Fn(Operation *moduleOp)
LLVM::LLVMFuncOp lookupOrCreatePrintU64Fn(Operation *moduleOp)
LLVM::FastmathFlags convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF)
Maps arithmetic fastmath enum values to LLVM enum values.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns)
Populates a pattern that rank-reduces n-D FMAs into (n-1)-D FMAs where n > 1.
SmallVector< int64_t > getAsIntegers(ArrayRef< Value > values)
Returns the integer numbers in values.
Definition: VectorOps.cpp:314
Include the generated interface declarations.
bool isLastMemrefDimUnitStride(MemRefType type)
Return "true" if the last dimension of the given type has a static unit stride.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, Type targetType, Value value)
Create a cast from an index-like value (index or integer) to another index-like value.
Definition: Utils.cpp:120
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, Builder &b)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
void populateVectorToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, bool reassociateFPReductions=false, bool force32BitVectorIndices=false)
Collect a set of patterns to convert from the Vector dialect to LLVM.
void populateVectorToLLVMMatrixConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to convert from Vector contractions to LLVM Matrix Intrinsics.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358