MLIR  18.0.0git
ConvertVectorToLLVM.cpp
Go to the documentation of this file.
1 //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
26 #include "mlir/IR/BuiltinTypes.h"
27 #include "mlir/IR/TypeUtilities.h"
30 #include "llvm/ADT/APFloat.h"
31 #include "llvm/Support/Casting.h"
32 #include <optional>
33 
34 using namespace mlir;
35 using namespace mlir::vector;
36 
37 // Helper to reduce vector type by *all* but one rank at back.
38 static VectorType reducedVectorTypeBack(VectorType tp) {
39  assert((tp.getRank() > 1) && "unlowerable vector type");
40  return VectorType::get(tp.getShape().take_back(), tp.getElementType(),
41  tp.getScalableDims().take_back());
42 }
43 
44 // Helper that picks the proper sequence for inserting.
46  const LLVMTypeConverter &typeConverter, Location loc,
47  Value val1, Value val2, Type llvmType, int64_t rank,
48  int64_t pos) {
49  assert(rank > 0 && "0-D vector corner case should have been handled already");
50  if (rank == 1) {
51  auto idxType = rewriter.getIndexType();
52  auto constant = rewriter.create<LLVM::ConstantOp>(
53  loc, typeConverter.convertType(idxType),
54  rewriter.getIntegerAttr(idxType, pos));
55  return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
56  constant);
57  }
58  return rewriter.create<LLVM::InsertValueOp>(loc, val1, val2, pos);
59 }
60 
61 // Helper that picks the proper sequence for extracting.
63  const LLVMTypeConverter &typeConverter, Location loc,
64  Value val, Type llvmType, int64_t rank, int64_t pos) {
65  if (rank <= 1) {
66  auto idxType = rewriter.getIndexType();
67  auto constant = rewriter.create<LLVM::ConstantOp>(
68  loc, typeConverter.convertType(idxType),
69  rewriter.getIntegerAttr(idxType, pos));
70  return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
71  constant);
72  }
73  return rewriter.create<LLVM::ExtractValueOp>(loc, val, pos);
74 }
75 
76 // Helper that returns data layout alignment of a memref.
78  MemRefType memrefType, unsigned &align) {
79  Type elementTy = typeConverter.convertType(memrefType.getElementType());
80  if (!elementTy)
81  return failure();
82 
83  // TODO: this should use the MLIR data layout when it becomes available and
84  // stop depending on translation.
85  llvm::LLVMContext llvmContext;
86  align = LLVM::TypeToLLVMIRTranslator(llvmContext)
87  .getPreferredAlignment(elementTy, typeConverter.getDataLayout());
88  return success();
89 }
90 
91 // Check if the last stride is non-unit and has a valid memory space.
92 static LogicalResult isMemRefTypeSupported(MemRefType memRefType,
93  const LLVMTypeConverter &converter) {
94  if (!isLastMemrefDimUnitStride(memRefType))
95  return failure();
96  if (failed(converter.getMemRefAddressSpace(memRefType)))
97  return failure();
98  return success();
99 }
100 
101 // Add an index vector component to a base pointer.
103  const LLVMTypeConverter &typeConverter,
104  MemRefType memRefType, Value llvmMemref, Value base,
105  Value index, uint64_t vLen) {
106  assert(succeeded(isMemRefTypeSupported(memRefType, typeConverter)) &&
107  "unsupported memref type");
108  auto pType = MemRefDescriptor(llvmMemref).getElementPtrType();
109  auto ptrsType = LLVM::getFixedVectorType(pType, vLen);
110  return rewriter.create<LLVM::GEPOp>(
111  loc, ptrsType, typeConverter.convertType(memRefType.getElementType()),
112  base, index);
113 }
114 
115 /// Convert `foldResult` into a Value. Integer attribute is converted to
116 /// an LLVM constant op.
117 static Value getAsLLVMValue(OpBuilder &builder, Location loc,
118  OpFoldResult foldResult) {
119  if (auto attr = foldResult.dyn_cast<Attribute>()) {
120  auto intAttr = cast<IntegerAttr>(attr);
121  return builder.create<LLVM::ConstantOp>(loc, intAttr).getResult();
122  }
123 
124  return foldResult.get<Value>();
125 }
126 
127 namespace {
128 
129 /// Trivial Vector to LLVM conversions
130 using VectorScaleOpConversion =
132 
133 /// Conversion pattern for a vector.bitcast.
134 class VectorBitCastOpConversion
135  : public ConvertOpToLLVMPattern<vector::BitCastOp> {
136 public:
138 
140  matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor,
141  ConversionPatternRewriter &rewriter) const override {
142  // Only 0-D and 1-D vectors can be lowered to LLVM.
143  VectorType resultTy = bitCastOp.getResultVectorType();
144  if (resultTy.getRank() > 1)
145  return failure();
146  Type newResultTy = typeConverter->convertType(resultTy);
147  rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy,
148  adaptor.getOperands()[0]);
149  return success();
150  }
151 };
152 
153 /// Conversion pattern for a vector.matrix_multiply.
154 /// This is lowered directly to the proper llvm.intr.matrix.multiply.
155 class VectorMatmulOpConversion
156  : public ConvertOpToLLVMPattern<vector::MatmulOp> {
157 public:
159 
161  matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor,
162  ConversionPatternRewriter &rewriter) const override {
163  rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
164  matmulOp, typeConverter->convertType(matmulOp.getRes().getType()),
165  adaptor.getLhs(), adaptor.getRhs(), matmulOp.getLhsRows(),
166  matmulOp.getLhsColumns(), matmulOp.getRhsColumns());
167  return success();
168  }
169 };
170 
171 /// Conversion pattern for a vector.flat_transpose.
172 /// This is lowered directly to the proper llvm.intr.matrix.transpose.
173 class VectorFlatTransposeOpConversion
174  : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> {
175 public:
177 
179  matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor,
180  ConversionPatternRewriter &rewriter) const override {
181  rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
182  transOp, typeConverter->convertType(transOp.getRes().getType()),
183  adaptor.getMatrix(), transOp.getRows(), transOp.getColumns());
184  return success();
185  }
186 };
187 
188 /// Overloaded utility that replaces a vector.load, vector.store,
189 /// vector.maskedload and vector.maskedstore with their respective LLVM
190 /// couterparts.
191 static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
192  vector::LoadOpAdaptor adaptor,
193  VectorType vectorTy, Value ptr, unsigned align,
194  ConversionPatternRewriter &rewriter) {
195  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, vectorTy, ptr, align);
196 }
197 
198 static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
199  vector::MaskedLoadOpAdaptor adaptor,
200  VectorType vectorTy, Value ptr, unsigned align,
201  ConversionPatternRewriter &rewriter) {
202  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
203  loadOp, vectorTy, ptr, adaptor.getMask(), adaptor.getPassThru(), align);
204 }
205 
206 static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
207  vector::StoreOpAdaptor adaptor,
208  VectorType vectorTy, Value ptr, unsigned align,
209  ConversionPatternRewriter &rewriter) {
210  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValueToStore(),
211  ptr, align);
212 }
213 
214 static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
215  vector::MaskedStoreOpAdaptor adaptor,
216  VectorType vectorTy, Value ptr, unsigned align,
217  ConversionPatternRewriter &rewriter) {
218  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
219  storeOp, adaptor.getValueToStore(), ptr, adaptor.getMask(), align);
220 }
221 
222 /// Conversion pattern for a vector.load, vector.store, vector.maskedload, and
223 /// vector.maskedstore.
224 template <class LoadOrStoreOp, class LoadOrStoreOpAdaptor>
225 class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
226 public:
228 
230  matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
231  typename LoadOrStoreOp::Adaptor adaptor,
232  ConversionPatternRewriter &rewriter) const override {
233  // Only 1-D vectors can be lowered to LLVM.
234  VectorType vectorTy = loadOrStoreOp.getVectorType();
235  if (vectorTy.getRank() > 1)
236  return failure();
237 
238  auto loc = loadOrStoreOp->getLoc();
239  MemRefType memRefTy = loadOrStoreOp.getMemRefType();
240 
241  // Resolve alignment.
242  unsigned align;
243  if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
244  return failure();
245 
246  // Resolve address.
247  auto vtype = cast<VectorType>(
248  this->typeConverter->convertType(loadOrStoreOp.getVectorType()));
249  Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(),
250  adaptor.getIndices(), rewriter);
251  replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, dataPtr, align,
252  rewriter);
253  return success();
254  }
255 };
256 
257 /// Conversion pattern for a vector.gather.
258 class VectorGatherOpConversion
259  : public ConvertOpToLLVMPattern<vector::GatherOp> {
260 public:
262 
264  matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
265  ConversionPatternRewriter &rewriter) const override {
266  MemRefType memRefType = dyn_cast<MemRefType>(gather.getBaseType());
267  assert(memRefType && "The base should be bufferized");
268 
269  if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
270  return failure();
271 
272  auto loc = gather->getLoc();
273 
274  // Resolve alignment.
275  unsigned align;
276  if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
277  return failure();
278 
279  Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
280  adaptor.getIndices(), rewriter);
281  Value base = adaptor.getBase();
282 
283  auto llvmNDVectorTy = adaptor.getIndexVec().getType();
284  // Handle the simple case of 1-D vector.
285  if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy)) {
286  auto vType = gather.getVectorType();
287  // Resolve address.
288  Value ptrs = getIndexedPtrs(rewriter, loc, *this->getTypeConverter(),
289  memRefType, base, ptr, adaptor.getIndexVec(),
290  /*vLen=*/vType.getDimSize(0));
291  // Replace with the gather intrinsic.
292  rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
293  gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
294  adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
295  return success();
296  }
297 
298  const LLVMTypeConverter &typeConverter = *this->getTypeConverter();
299  auto callback = [align, memRefType, base, ptr, loc, &rewriter,
300  &typeConverter](Type llvm1DVectorTy,
301  ValueRange vectorOperands) {
302  // Resolve address.
303  Value ptrs = getIndexedPtrs(
304  rewriter, loc, typeConverter, memRefType, base, ptr,
305  /*index=*/vectorOperands[0],
306  LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue());
307  // Create the gather intrinsic.
308  return rewriter.create<LLVM::masked_gather>(
309  loc, llvm1DVectorTy, ptrs, /*mask=*/vectorOperands[1],
310  /*passThru=*/vectorOperands[2], rewriter.getI32IntegerAttr(align));
311  };
312  SmallVector<Value> vectorOperands = {
313  adaptor.getIndexVec(), adaptor.getMask(), adaptor.getPassThru()};
315  gather, vectorOperands, *getTypeConverter(), callback, rewriter);
316  }
317 };
318 
319 /// Conversion pattern for a vector.scatter.
320 class VectorScatterOpConversion
321  : public ConvertOpToLLVMPattern<vector::ScatterOp> {
322 public:
324 
326  matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
327  ConversionPatternRewriter &rewriter) const override {
328  auto loc = scatter->getLoc();
329  MemRefType memRefType = scatter.getMemRefType();
330 
331  if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
332  return failure();
333 
334  // Resolve alignment.
335  unsigned align;
336  if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
337  return failure();
338 
339  // Resolve address.
340  VectorType vType = scatter.getVectorType();
341  Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
342  adaptor.getIndices(), rewriter);
343  Value ptrs = getIndexedPtrs(
344  rewriter, loc, *this->getTypeConverter(), memRefType, adaptor.getBase(),
345  ptr, adaptor.getIndexVec(), /*vLen=*/vType.getDimSize(0));
346 
347  // Replace with the scatter intrinsic.
348  rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
349  scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(),
350  rewriter.getI32IntegerAttr(align));
351  return success();
352  }
353 };
354 
355 /// Conversion pattern for a vector.expandload.
356 class VectorExpandLoadOpConversion
357  : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
358 public:
360 
362  matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor,
363  ConversionPatternRewriter &rewriter) const override {
364  auto loc = expand->getLoc();
365  MemRefType memRefType = expand.getMemRefType();
366 
367  // Resolve address.
368  auto vtype = typeConverter->convertType(expand.getVectorType());
369  Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
370  adaptor.getIndices(), rewriter);
371 
372  rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
373  expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru());
374  return success();
375  }
376 };
377 
378 /// Conversion pattern for a vector.compressstore.
379 class VectorCompressStoreOpConversion
380  : public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
381 public:
383 
385  matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor,
386  ConversionPatternRewriter &rewriter) const override {
387  auto loc = compress->getLoc();
388  MemRefType memRefType = compress.getMemRefType();
389 
390  // Resolve address.
391  Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
392  adaptor.getIndices(), rewriter);
393 
394  rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
395  compress, adaptor.getValueToStore(), ptr, adaptor.getMask());
396  return success();
397  }
398 };
399 
/// Reduction neutral classes for overloading.
/// Each empty tag type selects one `createReductionNeutralValue` overload.
class ReductionNeutralZero {};    // zero of the given type
class ReductionNeutralIntOne {};  // integer one
class ReductionNeutralFPOne {};   // floating-point one
class ReductionNeutralAllOnes {}; // integer with every bit set
class ReductionNeutralSIntMin {}; // signed integer minimum
class ReductionNeutralUIntMin {}; // unsigned integer minimum
class ReductionNeutralSIntMax {}; // signed integer maximum
class ReductionNeutralUIntMax {}; // unsigned integer maximum
class ReductionNeutralFPMin {};   // start value for fp minimum reductions
class ReductionNeutralFPMax {};   // start value for fp maximum reductions
411 
412 /// Create the reduction neutral zero value.
413 static Value createReductionNeutralValue(ReductionNeutralZero neutral,
414  ConversionPatternRewriter &rewriter,
415  Location loc, Type llvmType) {
416  return rewriter.create<LLVM::ConstantOp>(loc, llvmType,
417  rewriter.getZeroAttr(llvmType));
418 }
419 
420 /// Create the reduction neutral integer one value.
421 static Value createReductionNeutralValue(ReductionNeutralIntOne neutral,
422  ConversionPatternRewriter &rewriter,
423  Location loc, Type llvmType) {
424  return rewriter.create<LLVM::ConstantOp>(
425  loc, llvmType, rewriter.getIntegerAttr(llvmType, 1));
426 }
427 
428 /// Create the reduction neutral fp one value.
429 static Value createReductionNeutralValue(ReductionNeutralFPOne neutral,
430  ConversionPatternRewriter &rewriter,
431  Location loc, Type llvmType) {
432  return rewriter.create<LLVM::ConstantOp>(
433  loc, llvmType, rewriter.getFloatAttr(llvmType, 1.0));
434 }
435 
436 /// Create the reduction neutral all-ones value.
437 static Value createReductionNeutralValue(ReductionNeutralAllOnes neutral,
438  ConversionPatternRewriter &rewriter,
439  Location loc, Type llvmType) {
440  return rewriter.create<LLVM::ConstantOp>(
441  loc, llvmType,
442  rewriter.getIntegerAttr(
443  llvmType, llvm::APInt::getAllOnes(llvmType.getIntOrFloatBitWidth())));
444 }
445 
446 /// Create the reduction neutral signed int minimum value.
447 static Value createReductionNeutralValue(ReductionNeutralSIntMin neutral,
448  ConversionPatternRewriter &rewriter,
449  Location loc, Type llvmType) {
450  return rewriter.create<LLVM::ConstantOp>(
451  loc, llvmType,
452  rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMinValue(
453  llvmType.getIntOrFloatBitWidth())));
454 }
455 
456 /// Create the reduction neutral unsigned int minimum value.
457 static Value createReductionNeutralValue(ReductionNeutralUIntMin neutral,
458  ConversionPatternRewriter &rewriter,
459  Location loc, Type llvmType) {
460  return rewriter.create<LLVM::ConstantOp>(
461  loc, llvmType,
462  rewriter.getIntegerAttr(llvmType, llvm::APInt::getMinValue(
463  llvmType.getIntOrFloatBitWidth())));
464 }
465 
466 /// Create the reduction neutral signed int maximum value.
467 static Value createReductionNeutralValue(ReductionNeutralSIntMax neutral,
468  ConversionPatternRewriter &rewriter,
469  Location loc, Type llvmType) {
470  return rewriter.create<LLVM::ConstantOp>(
471  loc, llvmType,
472  rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMaxValue(
473  llvmType.getIntOrFloatBitWidth())));
474 }
475 
476 /// Create the reduction neutral unsigned int maximum value.
477 static Value createReductionNeutralValue(ReductionNeutralUIntMax neutral,
478  ConversionPatternRewriter &rewriter,
479  Location loc, Type llvmType) {
480  return rewriter.create<LLVM::ConstantOp>(
481  loc, llvmType,
482  rewriter.getIntegerAttr(llvmType, llvm::APInt::getMaxValue(
483  llvmType.getIntOrFloatBitWidth())));
484 }
485 
486 /// Create the reduction neutral fp minimum value.
487 static Value createReductionNeutralValue(ReductionNeutralFPMin neutral,
488  ConversionPatternRewriter &rewriter,
489  Location loc, Type llvmType) {
490  auto floatType = cast<FloatType>(llvmType);
491  return rewriter.create<LLVM::ConstantOp>(
492  loc, llvmType,
493  rewriter.getFloatAttr(
494  llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
495  /*Negative=*/false)));
496 }
497 
498 /// Create the reduction neutral fp maximum value.
499 static Value createReductionNeutralValue(ReductionNeutralFPMax neutral,
500  ConversionPatternRewriter &rewriter,
501  Location loc, Type llvmType) {
502  auto floatType = cast<FloatType>(llvmType);
503  return rewriter.create<LLVM::ConstantOp>(
504  loc, llvmType,
505  rewriter.getFloatAttr(
506  llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
507  /*Negative=*/true)));
508 }
509 
510 /// Returns `accumulator` if it has a valid value. Otherwise, creates and
511 /// returns a new accumulator value using `ReductionNeutral`.
512 template <class ReductionNeutral>
513 static Value getOrCreateAccumulator(ConversionPatternRewriter &rewriter,
514  Location loc, Type llvmType,
515  Value accumulator) {
516  if (accumulator)
517  return accumulator;
518 
519  return createReductionNeutralValue(ReductionNeutral(), rewriter, loc,
520  llvmType);
521 }
522 
523 /// Creates a constant value with the 1-D vector shape provided in `llvmType`.
524 /// This is used as effective vector length by some intrinsics supporting
525 /// dynamic vector lengths at runtime.
526 static Value createVectorLengthValue(ConversionPatternRewriter &rewriter,
527  Location loc, Type llvmType) {
528  VectorType vType = cast<VectorType>(llvmType);
529  auto vShape = vType.getShape();
530  assert(vShape.size() == 1 && "Unexpected multi-dim vector type");
531 
532  return rewriter.create<LLVM::ConstantOp>(
533  loc, rewriter.getI32Type(),
534  rewriter.getIntegerAttr(rewriter.getI32Type(), vShape[0]));
535 }
536 
537 /// Helper method to lower a `vector.reduction` op that performs an arithmetic
538 /// operation like add,mul, etc.. `VectorOp` is the LLVM vector intrinsic to use
539 /// and `ScalarOp` is the scalar operation used to add the accumulation value if
540 /// non-null.
541 template <class LLVMRedIntrinOp, class ScalarOp>
542 static Value createIntegerReductionArithmeticOpLowering(
543  ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
544  Value vectorOperand, Value accumulator) {
545 
546  Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
547 
548  if (accumulator)
549  result = rewriter.create<ScalarOp>(loc, accumulator, result);
550  return result;
551 }
552 
553 /// Helper method to lower a `vector.reduction` operation that performs
554 /// a comparison operation like `min`/`max`. `VectorOp` is the LLVM vector
555 /// intrinsic to use and `predicate` is the predicate to use to compare+combine
556 /// the accumulator value if non-null.
557 template <class LLVMRedIntrinOp>
558 static Value createIntegerReductionComparisonOpLowering(
559  ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
560  Value vectorOperand, Value accumulator, LLVM::ICmpPredicate predicate) {
561  Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
562  if (accumulator) {
563  Value cmp =
564  rewriter.create<LLVM::ICmpOp>(loc, predicate, accumulator, result);
565  result = rewriter.create<LLVM::SelectOp>(loc, cmp, accumulator, result);
566  }
567  return result;
568 }
569 
570 namespace {
571 template <typename Source>
572 struct VectorToScalarMapper;
573 template <>
574 struct VectorToScalarMapper<LLVM::vector_reduce_fmaximum> {
575  using Type = LLVM::MaximumOp;
576 };
577 template <>
578 struct VectorToScalarMapper<LLVM::vector_reduce_fminimum> {
579  using Type = LLVM::MinimumOp;
580 };
581 template <>
582 struct VectorToScalarMapper<LLVM::vector_reduce_fmax> {
583  using Type = LLVM::MaxNumOp;
584 };
585 template <>
586 struct VectorToScalarMapper<LLVM::vector_reduce_fmin> {
587  using Type = LLVM::MinNumOp;
588 };
589 } // namespace
590 
591 template <class LLVMRedIntrinOp>
592 static Value createFPReductionComparisonOpLowering(
593  ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
594  Value vectorOperand, Value accumulator, LLVM::FastmathFlagsAttr fmf) {
595  Value result =
596  rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand, fmf);
597 
598  if (accumulator) {
599  result =
600  rewriter.create<typename VectorToScalarMapper<LLVMRedIntrinOp>::Type>(
601  loc, result, accumulator);
602  }
603 
604  return result;
605 }
606 
/// Mask neutral classes for overloading.
/// Each empty tag type selects one `getMaskNeutralValue` overload.
class MaskNeutralFMaximum {}; // filler for masked-off lanes of fmaximum
class MaskNeutralFMinimum {}; // filler for masked-off lanes of fminimum
610 
611 /// Get the mask neutral floating point maximum value
612 static llvm::APFloat
613 getMaskNeutralValue(MaskNeutralFMaximum,
614  const llvm::fltSemantics &floatSemantics) {
615  return llvm::APFloat::getSmallest(floatSemantics, /*Negative=*/true);
616 }
617 /// Get the mask neutral floating point minimum value
618 static llvm::APFloat
619 getMaskNeutralValue(MaskNeutralFMinimum,
620  const llvm::fltSemantics &floatSemantics) {
621  return llvm::APFloat::getLargest(floatSemantics, /*Negative=*/false);
622 }
623 
624 /// Create the mask neutral floating point MLIR vector constant
625 template <typename MaskNeutral>
626 static Value createMaskNeutralValue(ConversionPatternRewriter &rewriter,
627  Location loc, Type llvmType,
628  Type vectorType) {
629  const auto &floatSemantics = cast<FloatType>(llvmType).getFloatSemantics();
630  auto value = getMaskNeutralValue(MaskNeutral{}, floatSemantics);
631  auto denseValue =
632  DenseElementsAttr::get(vectorType.cast<ShapedType>(), value);
633  return rewriter.create<LLVM::ConstantOp>(loc, vectorType, denseValue);
634 }
635 
636 /// Lowers masked `fmaximum` and `fminimum` reductions using the non-masked
637 /// intrinsics. It is a workaround to overcome the lack of masked intrinsics for
638 /// `fmaximum`/`fminimum`.
639 /// More information: https://github.com/llvm/llvm-project/issues/64940
640 template <class LLVMRedIntrinOp, class MaskNeutral>
641 static Value
642 lowerMaskedReductionWithRegular(ConversionPatternRewriter &rewriter,
643  Location loc, Type llvmType,
644  Value vectorOperand, Value accumulator,
645  Value mask, LLVM::FastmathFlagsAttr fmf) {
646  const Value vectorMaskNeutral = createMaskNeutralValue<MaskNeutral>(
647  rewriter, loc, llvmType, vectorOperand.getType());
648  const Value selectedVectorByMask = rewriter.create<LLVM::SelectOp>(
649  loc, mask, vectorOperand, vectorMaskNeutral);
650  return createFPReductionComparisonOpLowering<LLVMRedIntrinOp>(
651  rewriter, loc, llvmType, selectedVectorByMask, accumulator, fmf);
652 }
653 
654 template <class LLVMRedIntrinOp, class ReductionNeutral>
655 static Value
656 lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, Location loc,
657  Type llvmType, Value vectorOperand,
658  Value accumulator, LLVM::FastmathFlagsAttr fmf) {
659  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
660  llvmType, accumulator);
661  return rewriter.create<LLVMRedIntrinOp>(loc, llvmType,
662  /*startValue=*/accumulator,
663  vectorOperand, fmf);
664 }
665 
666 /// Overloaded methods to lower a *predicated* reduction to an llvm instrinsic
667 /// that requires a start value. This start value format spans across fp
668 /// reductions without mask and all the masked reduction intrinsics.
669 template <class LLVMVPRedIntrinOp, class ReductionNeutral>
670 static Value
671 lowerPredicatedReductionWithStartValue(ConversionPatternRewriter &rewriter,
672  Location loc, Type llvmType,
673  Value vectorOperand, Value accumulator) {
674  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
675  llvmType, accumulator);
676  return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
677  /*startValue=*/accumulator,
678  vectorOperand);
679 }
680 
681 template <class LLVMVPRedIntrinOp, class ReductionNeutral>
682 static Value lowerPredicatedReductionWithStartValue(
683  ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
684  Value vectorOperand, Value accumulator, Value mask) {
685  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
686  llvmType, accumulator);
687  Value vectorLength =
688  createVectorLengthValue(rewriter, loc, vectorOperand.getType());
689  return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
690  /*startValue=*/accumulator,
691  vectorOperand, mask, vectorLength);
692 }
693 
694 template <class LLVMIntVPRedIntrinOp, class IntReductionNeutral,
695  class LLVMFPVPRedIntrinOp, class FPReductionNeutral>
696 static Value lowerPredicatedReductionWithStartValue(
697  ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
698  Value vectorOperand, Value accumulator, Value mask) {
699  if (llvmType.isIntOrIndex())
700  return lowerPredicatedReductionWithStartValue<LLVMIntVPRedIntrinOp,
701  IntReductionNeutral>(
702  rewriter, loc, llvmType, vectorOperand, accumulator, mask);
703 
704  // FP dispatch.
705  return lowerPredicatedReductionWithStartValue<LLVMFPVPRedIntrinOp,
706  FPReductionNeutral>(
707  rewriter, loc, llvmType, vectorOperand, accumulator, mask);
708 }
709 
710 /// Conversion pattern for all vector reductions.
711 class VectorReductionOpConversion
712  : public ConvertOpToLLVMPattern<vector::ReductionOp> {
713 public:
714  explicit VectorReductionOpConversion(const LLVMTypeConverter &typeConv,
715  bool reassociateFPRed)
716  : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
717  reassociateFPReductions(reassociateFPRed) {}
718 
720  matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
721  ConversionPatternRewriter &rewriter) const override {
722  auto kind = reductionOp.getKind();
723  Type eltType = reductionOp.getDest().getType();
724  Type llvmType = typeConverter->convertType(eltType);
725  Value operand = adaptor.getVector();
726  Value acc = adaptor.getAcc();
727  Location loc = reductionOp.getLoc();
728 
729  if (eltType.isIntOrIndex()) {
730  // Integer reductions: add/mul/min/max/and/or/xor.
731  Value result;
732  switch (kind) {
733  case vector::CombiningKind::ADD:
734  result =
735  createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_add,
736  LLVM::AddOp>(
737  rewriter, loc, llvmType, operand, acc);
738  break;
739  case vector::CombiningKind::MUL:
740  result =
741  createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_mul,
742  LLVM::MulOp>(
743  rewriter, loc, llvmType, operand, acc);
744  break;
745  case vector::CombiningKind::MINUI:
746  result = createIntegerReductionComparisonOpLowering<
747  LLVM::vector_reduce_umin>(rewriter, loc, llvmType, operand, acc,
748  LLVM::ICmpPredicate::ule);
749  break;
750  case vector::CombiningKind::MINSI:
751  result = createIntegerReductionComparisonOpLowering<
752  LLVM::vector_reduce_smin>(rewriter, loc, llvmType, operand, acc,
753  LLVM::ICmpPredicate::sle);
754  break;
755  case vector::CombiningKind::MAXUI:
756  result = createIntegerReductionComparisonOpLowering<
757  LLVM::vector_reduce_umax>(rewriter, loc, llvmType, operand, acc,
758  LLVM::ICmpPredicate::uge);
759  break;
760  case vector::CombiningKind::MAXSI:
761  result = createIntegerReductionComparisonOpLowering<
762  LLVM::vector_reduce_smax>(rewriter, loc, llvmType, operand, acc,
763  LLVM::ICmpPredicate::sge);
764  break;
765  case vector::CombiningKind::AND:
766  result =
767  createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_and,
768  LLVM::AndOp>(
769  rewriter, loc, llvmType, operand, acc);
770  break;
771  case vector::CombiningKind::OR:
772  result =
773  createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_or,
774  LLVM::OrOp>(
775  rewriter, loc, llvmType, operand, acc);
776  break;
777  case vector::CombiningKind::XOR:
778  result =
779  createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_xor,
780  LLVM::XOrOp>(
781  rewriter, loc, llvmType, operand, acc);
782  break;
783  default:
784  return failure();
785  }
786  rewriter.replaceOp(reductionOp, result);
787 
788  return success();
789  }
790 
791  if (!isa<FloatType>(eltType))
792  return failure();
793 
794  arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
795  LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
796  reductionOp.getContext(),
797  convertArithFastMathFlagsToLLVM(fMFAttr.getValue()));
799  reductionOp.getContext(),
800  fmf.getValue() | (reassociateFPReductions ? LLVM::FastmathFlags::reassoc
801  : LLVM::FastmathFlags::none));
802 
803  // Floating-point reductions: add/mul/min/max
804  Value result;
805  if (kind == vector::CombiningKind::ADD) {
806  result = lowerReductionWithStartValue<LLVM::vector_reduce_fadd,
807  ReductionNeutralZero>(
808  rewriter, loc, llvmType, operand, acc, fmf);
809  } else if (kind == vector::CombiningKind::MUL) {
810  result = lowerReductionWithStartValue<LLVM::vector_reduce_fmul,
811  ReductionNeutralFPOne>(
812  rewriter, loc, llvmType, operand, acc, fmf);
813  } else if (kind == vector::CombiningKind::MINIMUMF) {
814  result =
815  createFPReductionComparisonOpLowering<LLVM::vector_reduce_fminimum>(
816  rewriter, loc, llvmType, operand, acc, fmf);
817  } else if (kind == vector::CombiningKind::MAXIMUMF) {
818  result =
819  createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmaximum>(
820  rewriter, loc, llvmType, operand, acc, fmf);
821  } else if (kind == vector::CombiningKind::MINF) {
822  result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmin>(
823  rewriter, loc, llvmType, operand, acc, fmf);
824  } else if (kind == vector::CombiningKind::MAXF) {
825  result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmax>(
826  rewriter, loc, llvmType, operand, acc, fmf);
827  } else
828  return failure();
829 
830  rewriter.replaceOp(reductionOp, result);
831  return success();
832  }
833 
834 private:
835  const bool reassociateFPReductions;
836 };
837 
838 /// Base class to convert a `vector.mask` operation while matching traits
839 /// of the maskable operation nested inside. A `VectorMaskOpConversionBase`
840 /// instance matches against a `vector.mask` operation. The `matchAndRewrite`
841 /// method performs a second match against the maskable operation `MaskedOp`.
842 /// Finally, it invokes the virtual method `matchAndRewriteMaskableOp` to be
843 /// implemented by the concrete conversion classes. This method can match
844 /// against specific traits of the `vector.mask` and the maskable operation. It
845 /// must replace the `vector.mask` operation.
846 template <class MaskedOp>
847 class VectorMaskOpConversionBase
848  : public ConvertOpToLLVMPattern<vector::MaskOp> {
849 public:
851 
853  matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
854  ConversionPatternRewriter &rewriter) const final {
855  // Match against the maskable operation kind.
856  auto maskedOp = llvm::dyn_cast_or_null<MaskedOp>(maskOp.getMaskableOp());
857  if (!maskedOp)
858  return failure();
859  return matchAndRewriteMaskableOp(maskOp, maskedOp, rewriter);
860  }
861 
862 protected:
863  virtual LogicalResult
864  matchAndRewriteMaskableOp(vector::MaskOp maskOp,
865  vector::MaskableOpInterface maskableOp,
866  ConversionPatternRewriter &rewriter) const = 0;
867 };
868 
869 class MaskedReductionOpConversion
870  : public VectorMaskOpConversionBase<vector::ReductionOp> {
871 
872 public:
873  using VectorMaskOpConversionBase<
874  vector::ReductionOp>::VectorMaskOpConversionBase;
875 
876  LogicalResult matchAndRewriteMaskableOp(
877  vector::MaskOp maskOp, MaskableOpInterface maskableOp,
878  ConversionPatternRewriter &rewriter) const override {
879  auto reductionOp = cast<ReductionOp>(maskableOp.getOperation());
880  auto kind = reductionOp.getKind();
881  Type eltType = reductionOp.getDest().getType();
882  Type llvmType = typeConverter->convertType(eltType);
883  Value operand = reductionOp.getVector();
884  Value acc = reductionOp.getAcc();
885  Location loc = reductionOp.getLoc();
886 
887  arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
888  LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
889  reductionOp.getContext(),
890  convertArithFastMathFlagsToLLVM(fMFAttr.getValue()));
891 
892  Value result;
893  switch (kind) {
894  case vector::CombiningKind::ADD:
895  result = lowerPredicatedReductionWithStartValue<
896  LLVM::VPReduceAddOp, ReductionNeutralZero, LLVM::VPReduceFAddOp,
897  ReductionNeutralZero>(rewriter, loc, llvmType, operand, acc,
898  maskOp.getMask());
899  break;
900  case vector::CombiningKind::MUL:
901  result = lowerPredicatedReductionWithStartValue<
902  LLVM::VPReduceMulOp, ReductionNeutralIntOne, LLVM::VPReduceFMulOp,
903  ReductionNeutralFPOne>(rewriter, loc, llvmType, operand, acc,
904  maskOp.getMask());
905  break;
906  case vector::CombiningKind::MINUI:
907  result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMinOp,
908  ReductionNeutralUIntMax>(
909  rewriter, loc, llvmType, operand, acc, maskOp.getMask());
910  break;
911  case vector::CombiningKind::MINSI:
912  result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMinOp,
913  ReductionNeutralSIntMax>(
914  rewriter, loc, llvmType, operand, acc, maskOp.getMask());
915  break;
916  case vector::CombiningKind::MAXUI:
917  result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMaxOp,
918  ReductionNeutralUIntMin>(
919  rewriter, loc, llvmType, operand, acc, maskOp.getMask());
920  break;
921  case vector::CombiningKind::MAXSI:
922  result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMaxOp,
923  ReductionNeutralSIntMin>(
924  rewriter, loc, llvmType, operand, acc, maskOp.getMask());
925  break;
926  case vector::CombiningKind::AND:
927  result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceAndOp,
928  ReductionNeutralAllOnes>(
929  rewriter, loc, llvmType, operand, acc, maskOp.getMask());
930  break;
931  case vector::CombiningKind::OR:
932  result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceOrOp,
933  ReductionNeutralZero>(
934  rewriter, loc, llvmType, operand, acc, maskOp.getMask());
935  break;
936  case vector::CombiningKind::XOR:
937  result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceXorOp,
938  ReductionNeutralZero>(
939  rewriter, loc, llvmType, operand, acc, maskOp.getMask());
940  break;
941  case vector::CombiningKind::MINF:
942  result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMinOp,
943  ReductionNeutralFPMax>(
944  rewriter, loc, llvmType, operand, acc, maskOp.getMask());
945  break;
946  case vector::CombiningKind::MAXF:
947  result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMaxOp,
948  ReductionNeutralFPMin>(
949  rewriter, loc, llvmType, operand, acc, maskOp.getMask());
950  break;
951  case CombiningKind::MAXIMUMF:
952  result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fmaximum,
953  MaskNeutralFMaximum>(
954  rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
955  break;
956  case CombiningKind::MINIMUMF:
957  result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fminimum,
958  MaskNeutralFMinimum>(
959  rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
960  break;
961  }
962 
963  // Replace `vector.mask` operation altogether.
964  rewriter.replaceOp(maskOp, result);
965  return success();
966  }
967 };
968 
969 class VectorShuffleOpConversion
970  : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
971 public:
973 
975  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
976  ConversionPatternRewriter &rewriter) const override {
977  auto loc = shuffleOp->getLoc();
978  auto v1Type = shuffleOp.getV1VectorType();
979  auto v2Type = shuffleOp.getV2VectorType();
980  auto vectorType = shuffleOp.getResultVectorType();
981  Type llvmType = typeConverter->convertType(vectorType);
982  auto maskArrayAttr = shuffleOp.getMask();
983 
984  // Bail if result type cannot be lowered.
985  if (!llvmType)
986  return failure();
987 
988  // Get rank and dimension sizes.
989  int64_t rank = vectorType.getRank();
990 #ifndef NDEBUG
991  bool wellFormed0DCase =
992  v1Type.getRank() == 0 && v2Type.getRank() == 0 && rank == 1;
993  bool wellFormedNDCase =
994  v1Type.getRank() == rank && v2Type.getRank() == rank;
995  assert((wellFormed0DCase || wellFormedNDCase) && "op is not well-formed");
996 #endif
997 
998  // For rank 0 and 1, where both operands have *exactly* the same vector
999  // type, there is direct shuffle support in LLVM. Use it!
1000  if (rank <= 1 && v1Type == v2Type) {
1001  Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
1002  loc, adaptor.getV1(), adaptor.getV2(),
1003  LLVM::convertArrayToIndices<int32_t>(maskArrayAttr));
1004  rewriter.replaceOp(shuffleOp, llvmShuffleOp);
1005  return success();
1006  }
1007 
1008  // For all other cases, insert the individual values individually.
1009  int64_t v1Dim = v1Type.getDimSize(0);
1010  Type eltType;
1011  if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(llvmType))
1012  eltType = arrayType.getElementType();
1013  else
1014  eltType = cast<VectorType>(llvmType).getElementType();
1015  Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
1016  int64_t insPos = 0;
1017  for (const auto &en : llvm::enumerate(maskArrayAttr)) {
1018  int64_t extPos = cast<IntegerAttr>(en.value()).getInt();
1019  Value value = adaptor.getV1();
1020  if (extPos >= v1Dim) {
1021  extPos -= v1Dim;
1022  value = adaptor.getV2();
1023  }
1024  Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
1025  eltType, rank, extPos);
1026  insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
1027  llvmType, rank, insPos++);
1028  }
1029  rewriter.replaceOp(shuffleOp, insert);
1030  return success();
1031  }
1032 };
1033 
1034 class VectorExtractElementOpConversion
1035  : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
1036 public:
1037  using ConvertOpToLLVMPattern<
1038  vector::ExtractElementOp>::ConvertOpToLLVMPattern;
1039 
1041  matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
1042  ConversionPatternRewriter &rewriter) const override {
1043  auto vectorType = extractEltOp.getSourceVectorType();
1044  auto llvmType = typeConverter->convertType(vectorType.getElementType());
1045 
1046  // Bail if result type cannot be lowered.
1047  if (!llvmType)
1048  return failure();
1049 
1050  if (vectorType.getRank() == 0) {
1051  Location loc = extractEltOp.getLoc();
1052  auto idxType = rewriter.getIndexType();
1053  auto zero = rewriter.create<LLVM::ConstantOp>(
1054  loc, typeConverter->convertType(idxType),
1055  rewriter.getIntegerAttr(idxType, 0));
1056  rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
1057  extractEltOp, llvmType, adaptor.getVector(), zero);
1058  return success();
1059  }
1060 
1061  rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
1062  extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition());
1063  return success();
1064  }
1065 };
1066 
1067 class VectorExtractOpConversion
1068  : public ConvertOpToLLVMPattern<vector::ExtractOp> {
1069 public:
1071 
1073  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
1074  ConversionPatternRewriter &rewriter) const override {
1075  auto loc = extractOp->getLoc();
1076  auto resultType = extractOp.getResult().getType();
1077  auto llvmResultType = typeConverter->convertType(resultType);
1078  // Bail if result type cannot be lowered.
1079  if (!llvmResultType)
1080  return failure();
1081 
1082  SmallVector<OpFoldResult> positionVec;
1083  for (auto [idx, pos] : llvm::enumerate(extractOp.getMixedPosition())) {
1084  if (pos.is<Value>())
1085  // Make sure we use the value that has been already converted to LLVM.
1086  positionVec.push_back(adaptor.getDynamicPosition()[idx]);
1087  else
1088  positionVec.push_back(pos);
1089  }
1090 
1091  // Extract entire vector. Should be handled by folder, but just to be safe.
1092  ArrayRef<OpFoldResult> position(positionVec);
1093  if (position.empty()) {
1094  rewriter.replaceOp(extractOp, adaptor.getVector());
1095  return success();
1096  }
1097 
1098  // One-shot extraction of vector from array (only requires extractvalue).
1099  if (isa<VectorType>(resultType)) {
1100  if (extractOp.hasDynamicPosition())
1101  return failure();
1102 
1103  Value extracted = rewriter.create<LLVM::ExtractValueOp>(
1104  loc, adaptor.getVector(), getAsIntegers(position));
1105  rewriter.replaceOp(extractOp, extracted);
1106  return success();
1107  }
1108 
1109  // Potential extraction of 1-D vector from array.
1110  Value extracted = adaptor.getVector();
1111  if (position.size() > 1) {
1112  if (extractOp.hasDynamicPosition())
1113  return failure();
1114 
1115  SmallVector<int64_t> nMinusOnePosition =
1116  getAsIntegers(position.drop_back());
1117  extracted = rewriter.create<LLVM::ExtractValueOp>(loc, extracted,
1118  nMinusOnePosition);
1119  }
1120 
1121  Value lastPosition = getAsLLVMValue(rewriter, loc, position.back());
1122  // Remaining extraction of element from 1-D LLVM vector.
1123  rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(extractOp, extracted,
1124  lastPosition);
1125  return success();
1126  }
1127 };
1128 
1129 /// Conversion pattern that turns a vector.fma on a 1-D vector
1130 /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
1131 /// This does not match vectors of n >= 2 rank.
1132 ///
1133 /// Example:
1134 /// ```
1135 /// vector.fma %a, %a, %a : vector<8xf32>
1136 /// ```
1137 /// is converted to:
1138 /// ```
1139 /// llvm.intr.fmuladd %va, %va, %va:
1140 /// (!llvm."<8 x f32>">, !llvm<"<8 x f32>">, !llvm<"<8 x f32>">)
1141 /// -> !llvm."<8 x f32>">
1142 /// ```
1143 class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
1144 public:
1146 
1148  matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
1149  ConversionPatternRewriter &rewriter) const override {
1150  VectorType vType = fmaOp.getVectorType();
1151  if (vType.getRank() > 1)
1152  return failure();
1153 
1154  rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(
1155  fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
1156  return success();
1157  }
1158 };
1159 
1160 class VectorInsertElementOpConversion
1161  : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
1162 public:
1164 
1166  matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor,
1167  ConversionPatternRewriter &rewriter) const override {
1168  auto vectorType = insertEltOp.getDestVectorType();
1169  auto llvmType = typeConverter->convertType(vectorType);
1170 
1171  // Bail if result type cannot be lowered.
1172  if (!llvmType)
1173  return failure();
1174 
1175  if (vectorType.getRank() == 0) {
1176  Location loc = insertEltOp.getLoc();
1177  auto idxType = rewriter.getIndexType();
1178  auto zero = rewriter.create<LLVM::ConstantOp>(
1179  loc, typeConverter->convertType(idxType),
1180  rewriter.getIntegerAttr(idxType, 0));
1181  rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
1182  insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero);
1183  return success();
1184  }
1185 
1186  rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
1187  insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(),
1188  adaptor.getPosition());
1189  return success();
1190  }
1191 };
1192 
1193 class VectorInsertOpConversion
1194  : public ConvertOpToLLVMPattern<vector::InsertOp> {
1195 public:
1197 
1199  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
1200  ConversionPatternRewriter &rewriter) const override {
1201  auto loc = insertOp->getLoc();
1202  auto sourceType = insertOp.getSourceType();
1203  auto destVectorType = insertOp.getDestVectorType();
1204  auto llvmResultType = typeConverter->convertType(destVectorType);
1205  // Bail if result type cannot be lowered.
1206  if (!llvmResultType)
1207  return failure();
1208 
1209  SmallVector<OpFoldResult> positionVec;
1210  for (auto [idx, pos] : llvm::enumerate(insertOp.getMixedPosition())) {
1211  if (pos.is<Value>())
1212  // Make sure we use the value that has been already converted to LLVM.
1213  positionVec.push_back(adaptor.getDynamicPosition()[idx]);
1214  else
1215  positionVec.push_back(pos);
1216  }
1217 
1218  // Overwrite entire vector with value. Should be handled by folder, but
1219  // just to be safe.
1220  ArrayRef<OpFoldResult> position(positionVec);
1221  if (position.empty()) {
1222  rewriter.replaceOp(insertOp, adaptor.getSource());
1223  return success();
1224  }
1225 
1226  // One-shot insertion of a vector into an array (only requires insertvalue).
1227  if (isa<VectorType>(sourceType)) {
1228  if (insertOp.hasDynamicPosition())
1229  return failure();
1230 
1231  Value inserted = rewriter.create<LLVM::InsertValueOp>(
1232  loc, adaptor.getDest(), adaptor.getSource(), getAsIntegers(position));
1233  rewriter.replaceOp(insertOp, inserted);
1234  return success();
1235  }
1236 
1237  // Potential extraction of 1-D vector from array.
1238  Value extracted = adaptor.getDest();
1239  auto oneDVectorType = destVectorType;
1240  if (position.size() > 1) {
1241  if (insertOp.hasDynamicPosition())
1242  return failure();
1243 
1244  oneDVectorType = reducedVectorTypeBack(destVectorType);
1245  extracted = rewriter.create<LLVM::ExtractValueOp>(
1246  loc, extracted, getAsIntegers(position.drop_back()));
1247  }
1248 
1249  // Insertion of an element into a 1-D LLVM vector.
1250  Value inserted = rewriter.create<LLVM::InsertElementOp>(
1251  loc, typeConverter->convertType(oneDVectorType), extracted,
1252  adaptor.getSource(), getAsLLVMValue(rewriter, loc, position.back()));
1253 
1254  // Potential insertion of resulting 1-D vector into array.
1255  if (position.size() > 1) {
1256  if (insertOp.hasDynamicPosition())
1257  return failure();
1258 
1259  inserted = rewriter.create<LLVM::InsertValueOp>(
1260  loc, adaptor.getDest(), inserted,
1261  getAsIntegers(position.drop_back()));
1262  }
1263 
1264  rewriter.replaceOp(insertOp, inserted);
1265  return success();
1266  }
1267 };
1268 
1269 /// Lower vector.scalable.insert ops to LLVM vector.insert
1270 struct VectorScalableInsertOpLowering
1271  : public ConvertOpToLLVMPattern<vector::ScalableInsertOp> {
1272  using ConvertOpToLLVMPattern<
1273  vector::ScalableInsertOp>::ConvertOpToLLVMPattern;
1274 
1276  matchAndRewrite(vector::ScalableInsertOp insOp, OpAdaptor adaptor,
1277  ConversionPatternRewriter &rewriter) const override {
1278  rewriter.replaceOpWithNewOp<LLVM::vector_insert>(
1279  insOp, adaptor.getDest(), adaptor.getSource(), adaptor.getPos());
1280  return success();
1281  }
1282 };
1283 
1284 /// Lower vector.scalable.extract ops to LLVM vector.extract
1285 struct VectorScalableExtractOpLowering
1286  : public ConvertOpToLLVMPattern<vector::ScalableExtractOp> {
1287  using ConvertOpToLLVMPattern<
1288  vector::ScalableExtractOp>::ConvertOpToLLVMPattern;
1289 
1291  matchAndRewrite(vector::ScalableExtractOp extOp, OpAdaptor adaptor,
1292  ConversionPatternRewriter &rewriter) const override {
1293  rewriter.replaceOpWithNewOp<LLVM::vector_extract>(
1294  extOp, typeConverter->convertType(extOp.getResultVectorType()),
1295  adaptor.getSource(), adaptor.getPos());
1296  return success();
1297  }
1298 };
1299 
1300 /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
1301 ///
1302 /// Example:
1303 /// ```
1304 /// %d = vector.fma %a, %b, %c : vector<2x4xf32>
1305 /// ```
1306 /// is rewritten into:
1307 /// ```
1308 /// %r = splat %f0: vector<2x4xf32>
1309 /// %va = vector.extractvalue %a[0] : vector<2x4xf32>
1310 /// %vb = vector.extractvalue %b[0] : vector<2x4xf32>
1311 /// %vc = vector.extractvalue %c[0] : vector<2x4xf32>
1312 /// %vd = vector.fma %va, %vb, %vc : vector<4xf32>
1313 /// %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
1314 /// %va2 = vector.extractvalue %a2[1] : vector<2x4xf32>
1315 /// %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32>
1316 /// %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32>
1317 /// %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
1318 /// %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
1319 /// // %r3 holds the final value.
1320 /// ```
1321 class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
1322 public:
1324 
1325  void initialize() {
1326  // This pattern recursively unpacks one dimension at a time. The recursion
1327  // bounded as the rank is strictly decreasing.
1328  setHasBoundedRewriteRecursion();
1329  }
1330 
1331  LogicalResult matchAndRewrite(FMAOp op,
1332  PatternRewriter &rewriter) const override {
1333  auto vType = op.getVectorType();
1334  if (vType.getRank() < 2)
1335  return failure();
1336 
1337  auto loc = op.getLoc();
1338  auto elemType = vType.getElementType();
1339  Value zero = rewriter.create<arith::ConstantOp>(
1340  loc, elemType, rewriter.getZeroAttr(elemType));
1341  Value desc = rewriter.create<vector::SplatOp>(loc, vType, zero);
1342  for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
1343  Value extrLHS = rewriter.create<ExtractOp>(loc, op.getLhs(), i);
1344  Value extrRHS = rewriter.create<ExtractOp>(loc, op.getRhs(), i);
1345  Value extrACC = rewriter.create<ExtractOp>(loc, op.getAcc(), i);
1346  Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
1347  desc = rewriter.create<InsertOp>(loc, fma, desc, i);
1348  }
1349  rewriter.replaceOp(op, desc);
1350  return success();
1351  }
1352 };
1353 
1354 /// Returns the strides if the memory underlying `memRefType` has a contiguous
1355 /// static layout.
1356 static std::optional<SmallVector<int64_t, 4>>
1357 computeContiguousStrides(MemRefType memRefType) {
1358  int64_t offset;
1359  SmallVector<int64_t, 4> strides;
1360  if (failed(getStridesAndOffset(memRefType, strides, offset)))
1361  return std::nullopt;
1362  if (!strides.empty() && strides.back() != 1)
1363  return std::nullopt;
1364  // If no layout or identity layout, this is contiguous by definition.
1365  if (memRefType.getLayout().isIdentity())
1366  return strides;
1367 
1368  // Otherwise, we must determine contiguity form shapes. This can only ever
1369  // work in static cases because MemRefType is underspecified to represent
1370  // contiguous dynamic shapes in other ways than with just empty/identity
1371  // layout.
1372  auto sizes = memRefType.getShape();
1373  for (int index = 0, e = strides.size() - 1; index < e; ++index) {
1374  if (ShapedType::isDynamic(sizes[index + 1]) ||
1375  ShapedType::isDynamic(strides[index]) ||
1376  ShapedType::isDynamic(strides[index + 1]))
1377  return std::nullopt;
1378  if (strides[index] != strides[index + 1] * sizes[index + 1])
1379  return std::nullopt;
1380  }
1381  return strides;
1382 }
1383 
1384 class VectorTypeCastOpConversion
1385  : public ConvertOpToLLVMPattern<vector::TypeCastOp> {
1386 public:
1388 
1390  matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor,
1391  ConversionPatternRewriter &rewriter) const override {
1392  auto loc = castOp->getLoc();
1393  MemRefType sourceMemRefType =
1394  cast<MemRefType>(castOp.getOperand().getType());
1395  MemRefType targetMemRefType = castOp.getType();
1396 
1397  // Only static shape casts supported atm.
1398  if (!sourceMemRefType.hasStaticShape() ||
1399  !targetMemRefType.hasStaticShape())
1400  return failure();
1401 
1402  auto llvmSourceDescriptorTy =
1403  dyn_cast<LLVM::LLVMStructType>(adaptor.getOperands()[0].getType());
1404  if (!llvmSourceDescriptorTy)
1405  return failure();
1406  MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]);
1407 
1408  auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
1409  typeConverter->convertType(targetMemRefType));
1410  if (!llvmTargetDescriptorTy)
1411  return failure();
1412 
1413  // Only contiguous source buffers supported atm.
1414  auto sourceStrides = computeContiguousStrides(sourceMemRefType);
1415  if (!sourceStrides)
1416  return failure();
1417  auto targetStrides = computeContiguousStrides(targetMemRefType);
1418  if (!targetStrides)
1419  return failure();
1420  // Only support static strides for now, regardless of contiguity.
1421  if (llvm::any_of(*targetStrides, ShapedType::isDynamic))
1422  return failure();
1423 
1424  auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
1425 
1426  // Create descriptor.
1427  auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
1428  // Set allocated ptr.
1429  Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
1430  desc.setAllocatedPtr(rewriter, loc, allocated);
1431 
1432  // Set aligned ptr.
1433  Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
1434  desc.setAlignedPtr(rewriter, loc, ptr);
1435  // Fill offset 0.
1436  auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
1437  auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
1438  desc.setOffset(rewriter, loc, zero);
1439 
1440  // Fill size and stride descriptors in memref.
1441  for (const auto &indexedSize :
1442  llvm::enumerate(targetMemRefType.getShape())) {
1443  int64_t index = indexedSize.index();
1444  auto sizeAttr =
1445  rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
1446  auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
1447  desc.setSize(rewriter, loc, index, size);
1448  auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
1449  (*targetStrides)[index]);
1450  auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
1451  desc.setStride(rewriter, loc, index, stride);
1452  }
1453 
1454  rewriter.replaceOp(castOp, {desc});
1455  return success();
1456  }
1457 };
1458 
1459 /// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only).
1460 /// Non-scalable versions of this operation are handled in Vector Transforms.
1461 class VectorCreateMaskOpRewritePattern
1462  : public OpRewritePattern<vector::CreateMaskOp> {
1463 public:
1464  explicit VectorCreateMaskOpRewritePattern(MLIRContext *context,
1465  bool enableIndexOpt)
1466  : OpRewritePattern<vector::CreateMaskOp>(context),
1467  force32BitVectorIndices(enableIndexOpt) {}
1468 
1469  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
1470  PatternRewriter &rewriter) const override {
1471  auto dstType = op.getType();
1472  if (dstType.getRank() != 1 || !cast<VectorType>(dstType).isScalable())
1473  return failure();
1474  IntegerType idxType =
1475  force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
1476  auto loc = op->getLoc();
1477  Value indices = rewriter.create<LLVM::StepVectorOp>(
1478  loc, LLVM::getVectorType(idxType, dstType.getShape()[0],
1479  /*isScalable=*/true));
1480  auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType,
1481  op.getOperand(0));
1482  Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
1483  Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
1484  indices, bounds);
1485  rewriter.replaceOp(op, comp);
1486  return success();
1487  }
1488 
1489 private:
1490  const bool force32BitVectorIndices;
1491 };
1492 
1493 class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
1494 public:
1496 
1497  // Lowering implementation that relies on a small runtime support library,
1498  // which only needs to provide a few printing methods (single value for all
1499  // data types, opening/closing bracket, comma, newline). The lowering splits
1500  // the vector into elementary printing operations. The advantage of this
1501  // approach is that the library can remain unaware of all low-level
1502  // implementation details of vectors while still supporting output of any
1503  // shaped and dimensioned vector.
1504  //
1505  // Note: This lowering only handles scalars, n-D vectors are broken into
1506  // printing scalars in loops in VectorToSCF.
1507  //
1508  // TODO: rely solely on libc in future? something else?
1509  //
1511  matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor,
1512  ConversionPatternRewriter &rewriter) const override {
1513  auto parent = printOp->getParentOfType<ModuleOp>();
1514  if (!parent)
1515  return failure();
1516 
1517  auto loc = printOp->getLoc();
1518 
1519  if (auto value = adaptor.getSource()) {
1520  Type printType = printOp.getPrintType();
1521  if (isa<VectorType>(printType)) {
1522  // Vectors should be broken into elementary print ops in VectorToSCF.
1523  return failure();
1524  }
1525  if (failed(emitScalarPrint(rewriter, parent, loc, printType, value)))
1526  return failure();
1527  }
1528 
1529  auto punct = printOp.getPunctuation();
1530  if (auto stringLiteral = printOp.getStringLiteral()) {
1531  LLVM::createPrintStrCall(rewriter, loc, parent, "vector_print_str",
1532  *stringLiteral, *getTypeConverter());
1533  } else if (punct != PrintPunctuation::NoPunctuation) {
1534  emitCall(rewriter, printOp->getLoc(), [&] {
1535  switch (punct) {
1536  case PrintPunctuation::Close:
1537  return LLVM::lookupOrCreatePrintCloseFn(parent);
1538  case PrintPunctuation::Open:
1539  return LLVM::lookupOrCreatePrintOpenFn(parent);
1540  case PrintPunctuation::Comma:
1541  return LLVM::lookupOrCreatePrintCommaFn(parent);
1542  case PrintPunctuation::NewLine:
1543  return LLVM::lookupOrCreatePrintNewlineFn(parent);
1544  default:
1545  llvm_unreachable("unexpected punctuation");
1546  }
1547  }());
1548  }
1549 
1550  rewriter.eraseOp(printOp);
1551  return success();
1552  }
1553 
1554 private:
1555  enum class PrintConversion {
1556  // clang-format off
1557  None,
1558  ZeroExt64,
1559  SignExt64,
1560  Bitcast16
1561  // clang-format on
1562  };
1563 
1564  LogicalResult emitScalarPrint(ConversionPatternRewriter &rewriter,
1565  ModuleOp parent, Location loc, Type printType,
1566  Value value) const {
1567  if (typeConverter->convertType(printType) == nullptr)
1568  return failure();
1569 
1570  // Make sure element type has runtime support.
1571  PrintConversion conversion = PrintConversion::None;
1572  Operation *printer;
1573  if (printType.isF32()) {
1574  printer = LLVM::lookupOrCreatePrintF32Fn(parent);
1575  } else if (printType.isF64()) {
1576  printer = LLVM::lookupOrCreatePrintF64Fn(parent);
1577  } else if (printType.isF16()) {
1578  conversion = PrintConversion::Bitcast16; // bits!
1579  printer = LLVM::lookupOrCreatePrintF16Fn(parent);
1580  } else if (printType.isBF16()) {
1581  conversion = PrintConversion::Bitcast16; // bits!
1582  printer = LLVM::lookupOrCreatePrintBF16Fn(parent);
1583  } else if (printType.isIndex()) {
1584  printer = LLVM::lookupOrCreatePrintU64Fn(parent);
1585  } else if (auto intTy = dyn_cast<IntegerType>(printType)) {
1586  // Integers need a zero or sign extension on the operand
1587  // (depending on the source type) as well as a signed or
1588  // unsigned print method. Up to 64-bit is supported.
1589  unsigned width = intTy.getWidth();
1590  if (intTy.isUnsigned()) {
1591  if (width <= 64) {
1592  if (width < 64)
1593  conversion = PrintConversion::ZeroExt64;
1594  printer = LLVM::lookupOrCreatePrintU64Fn(parent);
1595  } else {
1596  return failure();
1597  }
1598  } else {
1599  assert(intTy.isSignless() || intTy.isSigned());
1600  if (width <= 64) {
1601  // Note that we *always* zero extend booleans (1-bit integers),
1602  // so that true/false is printed as 1/0 rather than -1/0.
1603  if (width == 1)
1604  conversion = PrintConversion::ZeroExt64;
1605  else if (width < 64)
1606  conversion = PrintConversion::SignExt64;
1607  printer = LLVM::lookupOrCreatePrintI64Fn(parent);
1608  } else {
1609  return failure();
1610  }
1611  }
1612  } else {
1613  return failure();
1614  }
1615 
1616  switch (conversion) {
1617  case PrintConversion::ZeroExt64:
1618  value = rewriter.create<arith::ExtUIOp>(
1619  loc, IntegerType::get(rewriter.getContext(), 64), value);
1620  break;
1621  case PrintConversion::SignExt64:
1622  value = rewriter.create<arith::ExtSIOp>(
1623  loc, IntegerType::get(rewriter.getContext(), 64), value);
1624  break;
1625  case PrintConversion::Bitcast16:
1626  value = rewriter.create<LLVM::BitcastOp>(
1627  loc, IntegerType::get(rewriter.getContext(), 16), value);
1628  break;
1629  case PrintConversion::None:
1630  break;
1631  }
1632  emitCall(rewriter, loc, printer, value);
1633  return success();
1634  }
1635 
1636  // Helper to emit a call.
1637  static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
1638  Operation *ref, ValueRange params = ValueRange()) {
1639  rewriter.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(ref),
1640  params);
1641  }
1642 };
1643 
1644 /// The Splat operation is lowered to an insertelement + a shufflevector
1645 /// operation. Splat to only 0-d and 1-d vector result types are lowered.
1646 struct VectorSplatOpLowering : public ConvertOpToLLVMPattern<vector::SplatOp> {
1648 
1650  matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
1651  ConversionPatternRewriter &rewriter) const override {
1652  VectorType resultType = cast<VectorType>(splatOp.getType());
1653  if (resultType.getRank() > 1)
1654  return failure();
1655 
1656  // First insert it into an undef vector so we can shuffle it.
1657  auto vectorType = typeConverter->convertType(splatOp.getType());
1658  Value undef = rewriter.create<LLVM::UndefOp>(splatOp.getLoc(), vectorType);
1659  auto zero = rewriter.create<LLVM::ConstantOp>(
1660  splatOp.getLoc(),
1661  typeConverter->convertType(rewriter.getIntegerType(32)),
1662  rewriter.getZeroAttr(rewriter.getIntegerType(32)));
1663 
1664  // For 0-d vector, we simply do `insertelement`.
1665  if (resultType.getRank() == 0) {
1666  rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
1667  splatOp, vectorType, undef, adaptor.getInput(), zero);
1668  return success();
1669  }
1670 
1671  // For 1-d vector, we additionally do a `vectorshuffle`.
1672  auto v = rewriter.create<LLVM::InsertElementOp>(
1673  splatOp.getLoc(), vectorType, undef, adaptor.getInput(), zero);
1674 
1675  int64_t width = cast<VectorType>(splatOp.getType()).getDimSize(0);
1676  SmallVector<int32_t> zeroValues(width, 0);
1677 
1678  // Shuffle the value across the desired number of elements.
1679  rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, undef,
1680  zeroValues);
1681  return success();
1682  }
1683 };
1684 
1685 /// The Splat operation is lowered to an insertelement + a shufflevector
1686 /// operation. Splat to only 2+-d vector result types are lowered by the
1687 /// SplatNdOpLowering, the 1-d case is handled by SplatOpLowering.
1688 struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
1690 
1692  matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
1693  ConversionPatternRewriter &rewriter) const override {
1694  VectorType resultType = splatOp.getType();
1695  if (resultType.getRank() <= 1)
1696  return failure();
1697 
1698  // First insert it into an undef vector so we can shuffle it.
1699  auto loc = splatOp.getLoc();
1700  auto vectorTypeInfo =
1701  LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter());
1702  auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
1703  auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
1704  if (!llvmNDVectorTy || !llvm1DVectorTy)
1705  return failure();
1706 
1707  // Construct returned value.
1708  Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmNDVectorTy);
1709 
1710  // Construct a 1-D vector with the splatted value that we insert in all the
1711  // places within the returned descriptor.
1712  Value vdesc = rewriter.create<LLVM::UndefOp>(loc, llvm1DVectorTy);
1713  auto zero = rewriter.create<LLVM::ConstantOp>(
1714  loc, typeConverter->convertType(rewriter.getIntegerType(32)),
1715  rewriter.getZeroAttr(rewriter.getIntegerType(32)));
1716  Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
1717  adaptor.getInput(), zero);
1718 
1719  // Shuffle the value across the desired number of elements.
1720  int64_t width = resultType.getDimSize(resultType.getRank() - 1);
1721  SmallVector<int32_t> zeroValues(width, 0);
1722  v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroValues);
1723 
1724  // Iterate of linear index, convert to coords space and insert splatted 1-D
1725  // vector in each position.
1726  nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
1727  desc = rewriter.create<LLVM::InsertValueOp>(loc, desc, v, position);
1728  });
1729  rewriter.replaceOp(splatOp, desc);
1730  return success();
1731  }
1732 };
1733 
1734 } // namespace
1735 
1736 /// Populate the given list with patterns that convert from Vector to LLVM.
1738  LLVMTypeConverter &converter, RewritePatternSet &patterns,
1739  bool reassociateFPReductions, bool force32BitVectorIndices) {
1740  MLIRContext *ctx = converter.getDialect()->getContext();
1741  patterns.add<VectorFMAOpNDRewritePattern>(ctx);
1743  patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
1744  patterns.add<VectorCreateMaskOpRewritePattern>(ctx, force32BitVectorIndices);
1745  patterns
1746  .add<VectorBitCastOpConversion, VectorShuffleOpConversion,
1747  VectorExtractElementOpConversion, VectorExtractOpConversion,
1748  VectorFMAOp1DConversion, VectorInsertElementOpConversion,
1749  VectorInsertOpConversion, VectorPrintOpConversion,
1750  VectorTypeCastOpConversion, VectorScaleOpConversion,
1751  VectorLoadStoreConversion<vector::LoadOp, vector::LoadOpAdaptor>,
1752  VectorLoadStoreConversion<vector::MaskedLoadOp,
1753  vector::MaskedLoadOpAdaptor>,
1754  VectorLoadStoreConversion<vector::StoreOp, vector::StoreOpAdaptor>,
1755  VectorLoadStoreConversion<vector::MaskedStoreOp,
1756  vector::MaskedStoreOpAdaptor>,
1757  VectorGatherOpConversion, VectorScatterOpConversion,
1758  VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
1759  VectorSplatOpLowering, VectorSplatNdOpLowering,
1760  VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
1761  MaskedReductionOpConversion>(converter);
1762  // Transfer ops with rank > 1 are handled by VectorToSCF.
1763  populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
1764 }
1765 
1767  LLVMTypeConverter &converter, RewritePatternSet &patterns) {
1768  patterns.add<VectorMatmulOpConversion>(converter);
1769  patterns.add<VectorFlatTransposeOpConversion>(converter);
1770 }
static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc, const LLVMTypeConverter &typeConverter, MemRefType memRefType, Value llvmMemref, Value base, Value index, uint64_t vLen)
LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter, MemRefType memrefType, unsigned &align)
static Value extractOne(ConversionPatternRewriter &rewriter, const LLVMTypeConverter &typeConverter, Location loc, Value val, Type llvmType, int64_t rank, int64_t pos)
static Value insertOne(ConversionPatternRewriter &rewriter, const LLVMTypeConverter &typeConverter, Location loc, Value val1, Value val2, Type llvmType, int64_t rank, int64_t pos)
static Value getAsLLVMValue(OpBuilder &builder, Location loc, OpFoldResult foldResult)
Convert foldResult into a Value.
static LogicalResult isMemRefTypeSupported(MemRefType memRefType, const LLVMTypeConverter &converter)
static VectorType reducedVectorTypeBack(VectorType tp)
@ None
static void printOp(llvm::raw_ostream &os, Operation *op, OpPrintingFlags &flags)
Definition: Unit.cpp:19
Attributes are known-constant values of operations.
Definition: Attributes.h:25
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:216
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:238
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:261
IntegerType getI64Type()
Definition: Builders.cpp:85
IntegerType getI32Type()
Definition: Builders.cpp:83
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:87
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:331
MLIRContext * getContext() const
Definition: Builders.h:55
IndexType getIndexType()
Definition: Builders.cpp:71
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
Definition: Pattern.h:139
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:33
LLVM::LLVMDialect * getDialect() const
Returns the LLVM dialect.
Definition: TypeConverter.h:91
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type type or failure if...
const llvm::DataLayout & getDataLayout() const
Returns the data layout to use during and after conversion.
Utility class to translate MLIR LLVM dialect types to LLVM IR.
Definition: TypeToLLVM.h:39
unsigned getPreferredAlignment(Type type, const llvm::DataLayout &layout)
Returns the preferred alignment for the type given the data layout.
Definition: TypeToLLVM.cpp:196
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descriptor.
Definition: MemRefBuilder.h:33
static MemRefDescriptor undef(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating an undef value of the descriptor type.
LLVM::LLVMPointerType getElementPtrType()
Returns the (LLVM) pointer type this descriptor contains.
Generic implementation of one-to-one conversion from "SourceOp" to "TargetOp" where the latter belong...
Definition: Pattern.h:194
This class helps build Operations.
Definition: Builders.h:206
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:446
This class represents a single result from folding an operation.
Definition: OpDefinition.h:266
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:727
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:539
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
U cast() const
Definition: Types.h:339
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
Definition: Types.cpp:113
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:123
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:378
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder, function_ref< void(ArrayRef< int64_t >)> fun)
NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType, const LLVMTypeConverter &converter)
Type getVectorType(Type elementType, unsigned numElements, bool isScalable=false)
Creates an LLVM dialect-compatible vector type with the given element type and length.
Definition: LLVMTypes.cpp:915
LLVM::LLVMFuncOp lookupOrCreatePrintF16Fn(ModuleOp moduleOp)
Type getFixedVectorType(Type elementType, unsigned numElements)
Creates an LLVM dialect-compatible type with the given element type and length.
Definition: LLVMTypes.cpp:942
LLVM::LLVMFuncOp lookupOrCreatePrintU64Fn(ModuleOp moduleOp)
void createPrintStrCall(OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName, StringRef string, const LLVMTypeConverter &typeConverter, bool addNewline=true, std::optional< StringRef > runtimeFunctionName={})
Generate IR that prints the given string to stdout.
LLVM::LLVMFuncOp lookupOrCreatePrintF64Fn(ModuleOp moduleOp)
LLVM::LLVMFuncOp lookupOrCreatePrintBF16Fn(ModuleOp moduleOp)
LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(ModuleOp moduleOp)
llvm::ElementCount getVectorNumElements(Type type)
Returns the element count of any LLVM-compatible vector type.
Definition: LLVMTypes.cpp:888
LLVM::LLVMFuncOp lookupOrCreatePrintI64Fn(ModuleOp moduleOp)
Helper functions to lookup or create the declaration for commonly used external C function calls.
LLVM::FastmathFlags convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF)
void populateVectorTransferLoweringPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
SmallVector< int64_t > getAsIntegers(ArrayRef< Value > values)
Returns the integer numbers in values.
Definition: VectorOps.cpp:282
void populateVectorInsertExtractStridedSliceTransforms(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns with the following patterns.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
bool isLastMemrefDimUnitStride(MemRefType type)
Return "true" if the last dimension of the given type has a static unit stride.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
void populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, bool reassociateFPReductions=false, bool force32BitVectorIndices=false)
Collect a set of patterns to convert from the Vector dialect to LLVM.
void populateVectorToLLVMMatrixConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to convert from Vector contractions to LLVM Matrix Intrinsics.
Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, Type targetType, Value value)
Create a cast from an index-like value (index or integer) to another index-like value.
Definition: Utils.cpp:49
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction methods.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:357