MLIR  21.0.0git
ConvertVectorToLLVM.cpp
Go to the documentation of this file.
1 //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
27 #include "mlir/IR/BuiltinTypes.h"
28 #include "mlir/IR/TypeUtilities.h"
31 #include "llvm/ADT/APFloat.h"
32 #include "llvm/Support/Casting.h"
33 #include <optional>
34 
35 using namespace mlir;
36 using namespace mlir::vector;
37 
38 // Helper that picks the proper sequence for inserting.
40  const LLVMTypeConverter &typeConverter, Location loc,
41  Value val1, Value val2, Type llvmType, int64_t rank,
42  int64_t pos) {
43  assert(rank > 0 && "0-D vector corner case should have been handled already");
44  if (rank == 1) {
45  auto idxType = rewriter.getIndexType();
46  auto constant = rewriter.create<LLVM::ConstantOp>(
47  loc, typeConverter.convertType(idxType),
48  rewriter.getIntegerAttr(idxType, pos));
49  return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
50  constant);
51  }
52  return rewriter.create<LLVM::InsertValueOp>(loc, val1, val2, pos);
53 }
54 
55 // Helper that picks the proper sequence for extracting.
57  const LLVMTypeConverter &typeConverter, Location loc,
58  Value val, Type llvmType, int64_t rank, int64_t pos) {
59  if (rank <= 1) {
60  auto idxType = rewriter.getIndexType();
61  auto constant = rewriter.create<LLVM::ConstantOp>(
62  loc, typeConverter.convertType(idxType),
63  rewriter.getIntegerAttr(idxType, pos));
64  return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
65  constant);
66  }
67  return rewriter.create<LLVM::ExtractValueOp>(loc, val, pos);
68 }
69 
70 // Helper that returns data layout alignment of a vector.
71 LogicalResult getVectorAlignment(const LLVMTypeConverter &typeConverter,
72  VectorType vectorType, unsigned &align) {
73  Type convertedVectorTy = typeConverter.convertType(vectorType);
74  if (!convertedVectorTy)
75  return failure();
76 
77  llvm::LLVMContext llvmContext;
78  align = LLVM::TypeToLLVMIRTranslator(llvmContext)
79  .getPreferredAlignment(convertedVectorTy,
80  typeConverter.getDataLayout());
81 
82  return success();
83 }
84 
85 // Helper that returns data layout alignment of a memref.
86 LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter,
87  MemRefType memrefType, unsigned &align) {
88  Type elementTy = typeConverter.convertType(memrefType.getElementType());
89  if (!elementTy)
90  return failure();
91 
92  // TODO: this should use the MLIR data layout when it becomes available and
93  // stop depending on translation.
94  llvm::LLVMContext llvmContext;
95  align = LLVM::TypeToLLVMIRTranslator(llvmContext)
96  .getPreferredAlignment(elementTy, typeConverter.getDataLayout());
97  return success();
98 }
99 
100 // Helper to resolve the alignment for vector load/store, gather and scatter
101 // ops. If useVectorAlignment is true, get the preferred alignment for the
102 // vector type in the operation. This option is used for hardware backends with
103 // vectorization. Otherwise, use the preferred alignment of the element type of
104 // the memref. Note that if you choose to use vector alignment, the shape of the
105 // vector type must be resolved before the ConvertVectorToLLVM pass is run.
106 LogicalResult getVectorToLLVMAlignment(const LLVMTypeConverter &typeConverter,
107  VectorType vectorType,
108  MemRefType memrefType, unsigned &align,
109  bool useVectorAlignment) {
110  if (useVectorAlignment) {
111  if (failed(getVectorAlignment(typeConverter, vectorType, align))) {
112  return failure();
113  }
114  } else {
115  if (failed(getMemRefAlignment(typeConverter, memrefType, align))) {
116  return failure();
117  }
118  }
119  return success();
120 }
121 
// Check that the memref has a unit-stride innermost dimension and a memory
// space the converter can resolve. (Despite the original wording, success
// requires the last stride to BE unit — a non-unit last stride fails.)
static LogicalResult isMemRefTypeSupported(MemRefType memRefType,
                                           const LLVMTypeConverter &converter) {
  if (!memRefType.isLastDimUnitStride())
    return failure();
  if (failed(converter.getMemRefAddressSpace(memRefType)))
    return failure();
  return success();
}
131 
132 // Add an index vector component to a base pointer.
134  const LLVMTypeConverter &typeConverter,
135  MemRefType memRefType, Value llvmMemref, Value base,
136  Value index, VectorType vectorType) {
137  assert(succeeded(isMemRefTypeSupported(memRefType, typeConverter)) &&
138  "unsupported memref type");
139  assert(vectorType.getRank() == 1 && "expected a 1-d vector type");
140  auto pType = MemRefDescriptor(llvmMemref).getElementPtrType();
141  auto ptrsType =
142  LLVM::getVectorType(pType, vectorType.getDimSize(0),
143  /*isScalable=*/vectorType.getScalableDims()[0]);
144  return rewriter.create<LLVM::GEPOp>(
145  loc, ptrsType, typeConverter.convertType(memRefType.getElementType()),
146  base, index);
147 }
148 
149 /// Convert `foldResult` into a Value. Integer attribute is converted to
150 /// an LLVM constant op.
151 static Value getAsLLVMValue(OpBuilder &builder, Location loc,
152  OpFoldResult foldResult) {
153  if (auto attr = dyn_cast<Attribute>(foldResult)) {
154  auto intAttr = cast<IntegerAttr>(attr);
155  return builder.create<LLVM::ConstantOp>(loc, intAttr).getResult();
156  }
157 
158  return cast<Value>(foldResult);
159 }
160 
161 namespace {
162 
163 /// Trivial Vector to LLVM conversions
164 using VectorScaleOpConversion =
166 
167 /// Conversion pattern for a vector.bitcast.
168 class VectorBitCastOpConversion
169  : public ConvertOpToLLVMPattern<vector::BitCastOp> {
170 public:
172 
173  LogicalResult
174  matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor,
175  ConversionPatternRewriter &rewriter) const override {
176  // Only 0-D and 1-D vectors can be lowered to LLVM.
177  VectorType resultTy = bitCastOp.getResultVectorType();
178  if (resultTy.getRank() > 1)
179  return failure();
180  Type newResultTy = typeConverter->convertType(resultTy);
181  rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy,
182  adaptor.getOperands()[0]);
183  return success();
184  }
185 };
186 
187 /// Conversion pattern for a vector.matrix_multiply.
188 /// This is lowered directly to the proper llvm.intr.matrix.multiply.
189 class VectorMatmulOpConversion
190  : public ConvertOpToLLVMPattern<vector::MatmulOp> {
191 public:
193 
194  LogicalResult
195  matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor,
196  ConversionPatternRewriter &rewriter) const override {
197  rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
198  matmulOp, typeConverter->convertType(matmulOp.getRes().getType()),
199  adaptor.getLhs(), adaptor.getRhs(), matmulOp.getLhsRows(),
200  matmulOp.getLhsColumns(), matmulOp.getRhsColumns());
201  return success();
202  }
203 };
204 
205 /// Conversion pattern for a vector.flat_transpose.
206 /// This is lowered directly to the proper llvm.intr.matrix.transpose.
207 class VectorFlatTransposeOpConversion
208  : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> {
209 public:
211 
212  LogicalResult
213  matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor,
214  ConversionPatternRewriter &rewriter) const override {
215  rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
216  transOp, typeConverter->convertType(transOp.getRes().getType()),
217  adaptor.getMatrix(), transOp.getRows(), transOp.getColumns());
218  return success();
219  }
220 };
221 
/// Overloaded utility that replaces a vector.load, vector.store,
/// vector.maskedload and vector.maskedstore with their respective LLVM
/// counterparts.
static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
                                 vector::LoadOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  // Plain load: never volatile; the nontemporal hint is forwarded from the op.
  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, vectorTy, ptr, align,
                                            /*volatile_=*/false,
                                            loadOp.getNontemporal());
}
233 
// Masked-load overload: per-lane mask with a pass-through value for disabled
// lanes.
static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
                                 vector::MaskedLoadOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
      loadOp, vectorTy, ptr, adaptor.getMask(), adaptor.getPassThru(), align);
}
241 
// Plain-store overload: never volatile; forwards the nontemporal hint.
// Note: `vectorTy` is unused here — the stored value's type suffices.
static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
                                 vector::StoreOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValueToStore(),
                                             ptr, align, /*volatile_=*/false,
                                             storeOp.getNontemporal());
}
250 
// Masked-store overload: only lanes enabled by the mask are written.
static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
                                 vector::MaskedStoreOpAdaptor adaptor,
                                 VectorType vectorTy, Value ptr, unsigned align,
                                 ConversionPatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
      storeOp, adaptor.getValueToStore(), ptr, adaptor.getMask(), align);
}
258 
/// Conversion pattern for a vector.load, vector.store, vector.maskedload, and
/// vector.maskedstore.
/// Shared between the four ops via the `LoadOrStoreOp` template parameter;
/// the op-specific rewrite is dispatched through the replaceLoadOrStoreOp
/// overloads above.
template <class LoadOrStoreOp>
class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
public:
  explicit VectorLoadStoreConversion(const LLVMTypeConverter &typeConv,
                                     bool useVectorAlign)
      : ConvertOpToLLVMPattern<LoadOrStoreOp>(typeConv),
        useVectorAlignment(useVectorAlign) {}

  LogicalResult
  matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
                  typename LoadOrStoreOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only 1-D vectors can be lowered to LLVM.
    VectorType vectorTy = loadOrStoreOp.getVectorType();
    if (vectorTy.getRank() > 1)
      return failure();

    auto loc = loadOrStoreOp->getLoc();
    MemRefType memRefTy = loadOrStoreOp.getMemRefType();

    // Resolve alignment: either the vector type's preferred alignment or the
    // memref element type's, depending on `useVectorAlignment`.
    unsigned align;
    if (failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vectorTy,
                                        memRefTy, align, useVectorAlignment)))
      return rewriter.notifyMatchFailure(loadOrStoreOp,
                                         "could not resolve alignment");

    // Resolve address of the first accessed element.
    auto vtype = cast<VectorType>(
        this->typeConverter->convertType(loadOrStoreOp.getVectorType()));
    Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(),
                                               adaptor.getIndices(), rewriter);
    // Dispatch to the overload matching the concrete op kind.
    replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, dataPtr, align,
                         rewriter);
    return success();
  }

private:
  // If true, use the preferred alignment of the vector type.
  // If false, use the preferred alignment of the element type
  // of the memref. This flag is intended for use with hardware
  // backends that require alignment of vector operations.
  const bool useVectorAlignment;
};
306 
/// Conversion pattern for a vector.gather.
/// Lowers to llvm.masked.gather over a vector of per-lane pointers.
class VectorGatherOpConversion
    : public ConvertOpToLLVMPattern<vector::GatherOp> {
public:
  explicit VectorGatherOpConversion(const LLVMTypeConverter &typeConv,
                                    bool useVectorAlign)
      : ConvertOpToLLVMPattern<vector::GatherOp>(typeConv),
        useVectorAlignment(useVectorAlign) {}

  LogicalResult
  matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = gather->getLoc();
    MemRefType memRefType = dyn_cast<MemRefType>(gather.getBaseType());
    assert(memRefType && "The base should be bufferized");

    // Requires a unit-stride last dimension and a resolvable address space.
    if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
      return rewriter.notifyMatchFailure(gather, "memref type not supported");

    VectorType vType = gather.getVectorType();
    if (vType.getRank() > 1) {
      return rewriter.notifyMatchFailure(
          gather, "only 1-D vectors can be lowered to LLVM");
    }

    // Resolve alignment.
    unsigned align;
    if (failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
                                        memRefType, align, useVectorAlignment)))
      return rewriter.notifyMatchFailure(gather, "could not resolve alignment");

    // Resolve address: scalar element pointer plus a vector GEP over the
    // per-lane index vector.
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);
    Value base = adaptor.getBase();
    Value ptrs =
        getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
                       base, ptr, adaptor.getIndexVec(), vType);

    // Replace with the gather intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
        gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
        adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
    return success();
  }

private:
  // If true, use the preferred alignment of the vector type.
  // If false, use the preferred alignment of the element type
  // of the memref. This flag is intended for use with hardware
  // backends that require alignment of vector operations.
  const bool useVectorAlignment;
};
361 
/// Conversion pattern for a vector.scatter.
/// Mirrors VectorGatherOpConversion, lowering to llvm.masked.scatter.
class VectorScatterOpConversion
    : public ConvertOpToLLVMPattern<vector::ScatterOp> {
public:
  explicit VectorScatterOpConversion(const LLVMTypeConverter &typeConv,
                                     bool useVectorAlign)
      : ConvertOpToLLVMPattern<vector::ScatterOp>(typeConv),
        useVectorAlignment(useVectorAlign) {}

  LogicalResult
  matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = scatter->getLoc();
    MemRefType memRefType = scatter.getMemRefType();

    // Requires a unit-stride last dimension and a resolvable address space.
    if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
      return rewriter.notifyMatchFailure(scatter, "memref type not supported");

    VectorType vType = scatter.getVectorType();
    if (vType.getRank() > 1) {
      return rewriter.notifyMatchFailure(
          scatter, "only 1-D vectors can be lowered to LLVM");
    }

    // Resolve alignment.
    unsigned align;
    if (failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
                                        memRefType, align, useVectorAlignment)))
      return rewriter.notifyMatchFailure(scatter,
                                         "could not resolve alignment");

    // Resolve address: per-lane pointers computed from the index vector.
    Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
                                     adaptor.getIndices(), rewriter);
    Value ptrs =
        getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
                       adaptor.getBase(), ptr, adaptor.getIndexVec(), vType);

    // Replace with the scatter intrinsic.
    rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
        scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(),
        rewriter.getI32IntegerAttr(align));
    return success();
  }

private:
  // If true, use the preferred alignment of the vector type.
  // If false, use the preferred alignment of the element type
  // of the memref. This flag is intended for use with hardware
  // backends that require alignment of vector operations.
  const bool useVectorAlignment;
};
416 
417 /// Conversion pattern for a vector.expandload.
418 class VectorExpandLoadOpConversion
419  : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
420 public:
422 
423  LogicalResult
424  matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor,
425  ConversionPatternRewriter &rewriter) const override {
426  auto loc = expand->getLoc();
427  MemRefType memRefType = expand.getMemRefType();
428 
429  // Resolve address.
430  auto vtype = typeConverter->convertType(expand.getVectorType());
431  Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
432  adaptor.getIndices(), rewriter);
433 
434  rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
435  expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru());
436  return success();
437  }
438 };
439 
440 /// Conversion pattern for a vector.compressstore.
441 class VectorCompressStoreOpConversion
442  : public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
443 public:
445 
446  LogicalResult
447  matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor,
448  ConversionPatternRewriter &rewriter) const override {
449  auto loc = compress->getLoc();
450  MemRefType memRefType = compress.getMemRefType();
451 
452  // Resolve address.
453  Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
454  adaptor.getIndices(), rewriter);
455 
456  rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
457  compress, adaptor.getValueToStore(), ptr, adaptor.getMask());
458  return success();
459  }
460 };
461 
/// Reduction neutral classes for overloading.
/// Each empty tag type selects, via overload resolution, which neutral
/// (identity) element createReductionNeutralValue materializes below.
class ReductionNeutralZero {};
class ReductionNeutralIntOne {};
class ReductionNeutralFPOne {};
class ReductionNeutralAllOnes {};
class ReductionNeutralSIntMin {};
class ReductionNeutralUIntMin {};
class ReductionNeutralSIntMax {};
class ReductionNeutralUIntMax {};
class ReductionNeutralFPMin {};
class ReductionNeutralFPMax {};
473 
/// Create the reduction neutral zero value (identity of addition).
static Value createReductionNeutralValue(ReductionNeutralZero neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {
  return rewriter.create<LLVM::ConstantOp>(loc, llvmType,
                                           rewriter.getZeroAttr(llvmType));
}

/// Create the reduction neutral integer one value (identity of integer
/// multiplication).
static Value createReductionNeutralValue(ReductionNeutralIntOne neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {
  return rewriter.create<LLVM::ConstantOp>(
      loc, llvmType, rewriter.getIntegerAttr(llvmType, 1));
}

/// Create the reduction neutral fp one value (identity of fp multiplication).
static Value createReductionNeutralValue(ReductionNeutralFPOne neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {
  return rewriter.create<LLVM::ConstantOp>(
      loc, llvmType, rewriter.getFloatAttr(llvmType, 1.0));
}
497 
/// Create the reduction neutral all-ones value (identity of bitwise AND).
static Value createReductionNeutralValue(ReductionNeutralAllOnes neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {
  return rewriter.create<LLVM::ConstantOp>(
      loc, llvmType,
      rewriter.getIntegerAttr(
          llvmType, llvm::APInt::getAllOnes(llvmType.getIntOrFloatBitWidth())));
}

/// Create the reduction neutral signed int minimum value (identity of
/// signed max).
static Value createReductionNeutralValue(ReductionNeutralSIntMin neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {
  return rewriter.create<LLVM::ConstantOp>(
      loc, llvmType,
      rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMinValue(
                                            llvmType.getIntOrFloatBitWidth())));
}

/// Create the reduction neutral unsigned int minimum value, i.e. zero
/// (identity of unsigned max).
static Value createReductionNeutralValue(ReductionNeutralUIntMin neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {
  return rewriter.create<LLVM::ConstantOp>(
      loc, llvmType,
      rewriter.getIntegerAttr(llvmType, llvm::APInt::getMinValue(
                                            llvmType.getIntOrFloatBitWidth())));
}

/// Create the reduction neutral signed int maximum value (identity of
/// signed min).
static Value createReductionNeutralValue(ReductionNeutralSIntMax neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {
  return rewriter.create<LLVM::ConstantOp>(
      loc, llvmType,
      rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMaxValue(
                                            llvmType.getIntOrFloatBitWidth())));
}

/// Create the reduction neutral unsigned int maximum value (identity of
/// unsigned min).
static Value createReductionNeutralValue(ReductionNeutralUIntMax neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {
  return rewriter.create<LLVM::ConstantOp>(
      loc, llvmType,
      rewriter.getIntegerAttr(llvmType, llvm::APInt::getMaxValue(
                                            llvmType.getIntOrFloatBitWidth())));
}
547 
/// Create the reduction neutral fp minimum value: a positive quiet NaN.
// NOTE(review): a QNaN is neutral only under minnum/maxnum-style semantics
// where NaN operands are ignored — presumably paired with the fmin/fmax
// (not fminimum/fmaximum) reduction intrinsics; verify at call sites.
static Value createReductionNeutralValue(ReductionNeutralFPMin neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {
  auto floatType = cast<FloatType>(llvmType);
  return rewriter.create<LLVM::ConstantOp>(
      loc, llvmType,
      rewriter.getFloatAttr(
          llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
                                           /*Negative=*/false)));
}

/// Create the reduction neutral fp maximum value: a negative quiet NaN
/// (same NaN-ignoring caveat as the fp minimum neutral above).
static Value createReductionNeutralValue(ReductionNeutralFPMax neutral,
                                         ConversionPatternRewriter &rewriter,
                                         Location loc, Type llvmType) {
  auto floatType = cast<FloatType>(llvmType);
  return rewriter.create<LLVM::ConstantOp>(
      loc, llvmType,
      rewriter.getFloatAttr(
          llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
                                           /*Negative=*/true)));
}
571 
572 /// Returns `accumulator` if it has a valid value. Otherwise, creates and
573 /// returns a new accumulator value using `ReductionNeutral`.
574 template <class ReductionNeutral>
575 static Value getOrCreateAccumulator(ConversionPatternRewriter &rewriter,
576  Location loc, Type llvmType,
577  Value accumulator) {
578  if (accumulator)
579  return accumulator;
580 
581  return createReductionNeutralValue(ReductionNeutral(), rewriter, loc,
582  llvmType);
583 }
584 
/// Creates a value with the 1-D vector shape provided in `llvmType`.
/// This is used as effective vector length by some intrinsics supporting
/// dynamic vector lengths at runtime.
static Value createVectorLengthValue(ConversionPatternRewriter &rewriter,
                                     Location loc, Type llvmType) {
  VectorType vType = cast<VectorType>(llvmType);
  auto vShape = vType.getShape();
  assert(vShape.size() == 1 && "Unexpected multi-dim vector type");

  // Static (or minimal, if scalable) number of lanes as an i32 constant.
  Value baseVecLength = rewriter.create<LLVM::ConstantOp>(
      loc, rewriter.getI32Type(),
      rewriter.getIntegerAttr(rewriter.getI32Type(), vShape[0]));

  if (!vType.getScalableDims()[0])
    return baseVecLength;

  // For a scalable vector type, create and return `vScale * baseVecLength`.
  Value vScale = rewriter.create<vector::VectorScaleOp>(loc);
  vScale =
      rewriter.create<arith::IndexCastOp>(loc, rewriter.getI32Type(), vScale);
  Value scalableVecLength =
      rewriter.create<arith::MulIOp>(loc, baseVecLength, vScale);
  return scalableVecLength;
}
609 
/// Helper method to lower a `vector.reduction` op that performs an arithmetic
/// operation like add,mul, etc.. `VectorOp` is the LLVM vector intrinsic to use
/// and `ScalarOp` is the scalar operation used to add the accumulation value if
/// non-null.
template <class LLVMRedIntrinOp, class ScalarOp>
static Value createIntegerReductionArithmeticOpLowering(
    ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
    Value vectorOperand, Value accumulator) {

  Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);

  // Fold the optional accumulator in with a scalar op after the reduction.
  if (accumulator)
    result = rewriter.create<ScalarOp>(loc, accumulator, result);
  return result;
}
625 
/// Helper method to lower a `vector.reduction` operation that performs
/// a comparison operation like `min`/`max`. `VectorOp` is the LLVM vector
/// intrinsic to use and `predicate` is the predicate to use to compare+combine
/// the accumulator value if non-null.
template <class LLVMRedIntrinOp>
static Value createIntegerReductionComparisonOpLowering(
    ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
    Value vectorOperand, Value accumulator, LLVM::ICmpPredicate predicate) {
  Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
  if (accumulator) {
    // Combine with the accumulator via icmp + select: keeps the accumulator
    // when `predicate` holds, otherwise the reduced value.
    Value cmp =
        rewriter.create<LLVM::ICmpOp>(loc, predicate, accumulator, result);
    result = rewriter.create<LLVM::SelectOp>(loc, cmp, accumulator, result);
  }
  return result;
}
642 
namespace {
/// Maps an LLVM vector reduction intrinsic to the scalar op used to combine
/// the reduced value with the accumulator (NaN-propagating maximum/minimum
/// vs. NaN-ignoring maxnum/minnum).
template <typename Source>
struct VectorToScalarMapper;
template <>
struct VectorToScalarMapper<LLVM::vector_reduce_fmaximum> {
  using Type = LLVM::MaximumOp;
};
template <>
struct VectorToScalarMapper<LLVM::vector_reduce_fminimum> {
  using Type = LLVM::MinimumOp;
};
template <>
struct VectorToScalarMapper<LLVM::vector_reduce_fmax> {
  using Type = LLVM::MaxNumOp;
};
template <>
struct VectorToScalarMapper<LLVM::vector_reduce_fmin> {
  using Type = LLVM::MinNumOp;
};
} // namespace
663 
/// Lowers an fp min/max `vector.reduction`: emits the reduction intrinsic
/// (with fastmath flags), then combines the optional accumulator using the
/// matching scalar op from VectorToScalarMapper.
template <class LLVMRedIntrinOp>
static Value createFPReductionComparisonOpLowering(
    ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
    Value vectorOperand, Value accumulator, LLVM::FastmathFlagsAttr fmf) {
  Value result =
      rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand, fmf);

  if (accumulator) {
    result =
        rewriter.create<typename VectorToScalarMapper<LLVMRedIntrinOp>::Type>(
            loc, result, accumulator);
  }

  return result;
}
679 
/// Mask neutral classes for overloading: tag types selecting the value that
/// masked-off lanes are replaced with before an unmasked reduction.
class MaskNeutralFMaximum {};
class MaskNeutralFMinimum {};

/// Get the mask neutral floating point maximum value: the most negative
/// finite... actually the smallest (subnormal) negative value, which cannot
/// win an fmaximum against any in-range lane.
static llvm::APFloat
getMaskNeutralValue(MaskNeutralFMaximum,
                    const llvm::fltSemantics &floatSemantics) {
  return llvm::APFloat::getSmallest(floatSemantics, /*Negative=*/true);
}
/// Get the mask neutral floating point minimum value: the largest positive
/// finite value, which cannot win an fminimum against any in-range lane.
static llvm::APFloat
getMaskNeutralValue(MaskNeutralFMinimum,
                    const llvm::fltSemantics &floatSemantics) {
  return llvm::APFloat::getLargest(floatSemantics, /*Negative=*/false);
}
696 
/// Create the mask neutral floating point MLIR vector constant: a splat of
/// the scalar neutral over the full vector shape.
template <typename MaskNeutral>
static Value createMaskNeutralValue(ConversionPatternRewriter &rewriter,
                                    Location loc, Type llvmType,
                                    Type vectorType) {
  const auto &floatSemantics = cast<FloatType>(llvmType).getFloatSemantics();
  auto value = getMaskNeutralValue(MaskNeutral{}, floatSemantics);
  auto denseValue = DenseElementsAttr::get(cast<ShapedType>(vectorType), value);
  return rewriter.create<LLVM::ConstantOp>(loc, vectorType, denseValue);
}
707 
/// Lowers masked `fmaximum` and `fminimum` reductions using the non-masked
/// intrinsics. It is a workaround to overcome the lack of masked intrinsics for
/// `fmaximum`/`fminimum`.
/// More information: https://github.com/llvm/llvm-project/issues/64940
template <class LLVMRedIntrinOp, class MaskNeutral>
static Value
lowerMaskedReductionWithRegular(ConversionPatternRewriter &rewriter,
                                Location loc, Type llvmType,
                                Value vectorOperand, Value accumulator,
                                Value mask, LLVM::FastmathFlagsAttr fmf) {
  // Replace masked-off lanes by the neutral element, then reduce unmasked.
  const Value vectorMaskNeutral = createMaskNeutralValue<MaskNeutral>(
      rewriter, loc, llvmType, vectorOperand.getType());
  const Value selectedVectorByMask = rewriter.create<LLVM::SelectOp>(
      loc, mask, vectorOperand, vectorMaskNeutral);
  return createFPReductionComparisonOpLowering<LLVMRedIntrinOp>(
      rewriter, loc, llvmType, selectedVectorByMask, accumulator, fmf);
}
725 
/// Lowers a reduction to an intrinsic that takes an explicit start value
/// (e.g. fadd/fmul reductions); a missing accumulator is replaced by the
/// neutral element so the start value is always present.
template <class LLVMRedIntrinOp, class ReductionNeutral>
static Value
lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, Location loc,
                             Type llvmType, Value vectorOperand,
                             Value accumulator, LLVM::FastmathFlagsAttr fmf) {
  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
                                                         llvmType, accumulator);
  return rewriter.create<LLVMRedIntrinOp>(loc, llvmType,
                                          /*startValue=*/accumulator,
                                          vectorOperand, fmf);
}
737 
/// Overloaded methods to lower a *predicated* reduction to an llvm intrinsic
/// that requires a start value. This start value format spans across fp
/// reductions without mask and all the masked reduction intrinsics.
/// This overload handles the no-mask case.
template <class LLVMVPRedIntrinOp, class ReductionNeutral>
static Value
lowerPredicatedReductionWithStartValue(ConversionPatternRewriter &rewriter,
                                       Location loc, Type llvmType,
                                       Value vectorOperand, Value accumulator) {
  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
                                                         llvmType, accumulator);
  return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
                                            /*startValue=*/accumulator,
                                            vectorOperand);
}
752 
/// Masked overload: VP intrinsics additionally take the mask and an explicit
/// effective vector length, derived here from the operand's vector type.
template <class LLVMVPRedIntrinOp, class ReductionNeutral>
static Value lowerPredicatedReductionWithStartValue(
    ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
    Value vectorOperand, Value accumulator, Value mask) {
  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
                                                         llvmType, accumulator);
  Value vectorLength =
      createVectorLengthValue(rewriter, loc, vectorOperand.getType());
  return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
                                            /*startValue=*/accumulator,
                                            vectorOperand, mask, vectorLength);
}
765 
/// Dispatching overload: picks the integer or floating-point VP intrinsic
/// (and matching neutral element) based on the element type.
template <class LLVMIntVPRedIntrinOp, class IntReductionNeutral,
          class LLVMFPVPRedIntrinOp, class FPReductionNeutral>
static Value lowerPredicatedReductionWithStartValue(
    ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
    Value vectorOperand, Value accumulator, Value mask) {
  if (llvmType.isIntOrIndex())
    return lowerPredicatedReductionWithStartValue<LLVMIntVPRedIntrinOp,
                                                  IntReductionNeutral>(
        rewriter, loc, llvmType, vectorOperand, accumulator, mask);

  // FP dispatch.
  return lowerPredicatedReductionWithStartValue<LLVMFPVPRedIntrinOp,
                                                FPReductionNeutral>(
      rewriter, loc, llvmType, vectorOperand, accumulator, mask);
}
781 
782 /// Conversion pattern for all vector reductions.
783 class VectorReductionOpConversion
784  : public ConvertOpToLLVMPattern<vector::ReductionOp> {
785 public:
786  explicit VectorReductionOpConversion(const LLVMTypeConverter &typeConv,
787  bool reassociateFPRed)
788  : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
789  reassociateFPReductions(reassociateFPRed) {}
790 
791  LogicalResult
792  matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
793  ConversionPatternRewriter &rewriter) const override {
794  auto kind = reductionOp.getKind();
795  Type eltType = reductionOp.getDest().getType();
796  Type llvmType = typeConverter->convertType(eltType);
797  Value operand = adaptor.getVector();
798  Value acc = adaptor.getAcc();
799  Location loc = reductionOp.getLoc();
800 
801  if (eltType.isIntOrIndex()) {
802  // Integer reductions: add/mul/min/max/and/or/xor.
803  Value result;
804  switch (kind) {
805  case vector::CombiningKind::ADD:
806  result =
807  createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_add,
808  LLVM::AddOp>(
809  rewriter, loc, llvmType, operand, acc);
810  break;
811  case vector::CombiningKind::MUL:
812  result =
813  createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_mul,
814  LLVM::MulOp>(
815  rewriter, loc, llvmType, operand, acc);
816  break;
818  result = createIntegerReductionComparisonOpLowering<
819  LLVM::vector_reduce_umin>(rewriter, loc, llvmType, operand, acc,
820  LLVM::ICmpPredicate::ule);
821  break;
822  case vector::CombiningKind::MINSI:
823  result = createIntegerReductionComparisonOpLowering<
824  LLVM::vector_reduce_smin>(rewriter, loc, llvmType, operand, acc,
825  LLVM::ICmpPredicate::sle);
826  break;
827  case vector::CombiningKind::MAXUI:
828  result = createIntegerReductionComparisonOpLowering<
829  LLVM::vector_reduce_umax>(rewriter, loc, llvmType, operand, acc,
830  LLVM::ICmpPredicate::uge);
831  break;
832  case vector::CombiningKind::MAXSI:
833  result = createIntegerReductionComparisonOpLowering<
834  LLVM::vector_reduce_smax>(rewriter, loc, llvmType, operand, acc,
835  LLVM::ICmpPredicate::sge);
836  break;
837  case vector::CombiningKind::AND:
838  result =
839  createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_and,
840  LLVM::AndOp>(
841  rewriter, loc, llvmType, operand, acc);
842  break;
843  case vector::CombiningKind::OR:
844  result =
845  createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_or,
846  LLVM::OrOp>(
847  rewriter, loc, llvmType, operand, acc);
848  break;
849  case vector::CombiningKind::XOR:
850  result =
851  createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_xor,
852  LLVM::XOrOp>(
853  rewriter, loc, llvmType, operand, acc);
854  break;
855  default:
856  return failure();
857  }
858  rewriter.replaceOp(reductionOp, result);
859 
860  return success();
861  }
862 
863  if (!isa<FloatType>(eltType))
864  return failure();
865 
866  arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
867  LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
868  reductionOp.getContext(),
869  convertArithFastMathFlagsToLLVM(fMFAttr.getValue()));
871  reductionOp.getContext(),
872  fmf.getValue() | (reassociateFPReductions ? LLVM::FastmathFlags::reassoc
873  : LLVM::FastmathFlags::none));
874 
875  // Floating-point reductions: add/mul/min/max
876  Value result;
877  if (kind == vector::CombiningKind::ADD) {
878  result = lowerReductionWithStartValue<LLVM::vector_reduce_fadd,
879  ReductionNeutralZero>(
880  rewriter, loc, llvmType, operand, acc, fmf);
881  } else if (kind == vector::CombiningKind::MUL) {
882  result = lowerReductionWithStartValue<LLVM::vector_reduce_fmul,
883  ReductionNeutralFPOne>(
884  rewriter, loc, llvmType, operand, acc, fmf);
885  } else if (kind == vector::CombiningKind::MINIMUMF) {
886  result =
887  createFPReductionComparisonOpLowering<LLVM::vector_reduce_fminimum>(
888  rewriter, loc, llvmType, operand, acc, fmf);
889  } else if (kind == vector::CombiningKind::MAXIMUMF) {
890  result =
891  createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmaximum>(
892  rewriter, loc, llvmType, operand, acc, fmf);
893  } else if (kind == vector::CombiningKind::MINNUMF) {
894  result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmin>(
895  rewriter, loc, llvmType, operand, acc, fmf);
896  } else if (kind == vector::CombiningKind::MAXNUMF) {
897  result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmax>(
898  rewriter, loc, llvmType, operand, acc, fmf);
899  } else {
900  return failure();
901  }
902 
903  rewriter.replaceOp(reductionOp, result);
904  return success();
905  }
906 
907 private:
908  const bool reassociateFPReductions;
909 };
910 
911 /// Base class to convert a `vector.mask` operation while matching traits
912 /// of the maskable operation nested inside. A `VectorMaskOpConversionBase`
913 /// instance matches against a `vector.mask` operation. The `matchAndRewrite`
914 /// method performs a second match against the maskable operation `MaskedOp`.
915 /// Finally, it invokes the virtual method `matchAndRewriteMaskableOp` to be
916 /// implemented by the concrete conversion classes. This method can match
917 /// against specific traits of the `vector.mask` and the maskable operation. It
918 /// must replace the `vector.mask` operation.
919 template <class MaskedOp>
920 class VectorMaskOpConversionBase
921  : public ConvertOpToLLVMPattern<vector::MaskOp> {
922 public:
924 
925  LogicalResult
926  matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
927  ConversionPatternRewriter &rewriter) const final {
928  // Match against the maskable operation kind.
929  auto maskedOp = llvm::dyn_cast_or_null<MaskedOp>(maskOp.getMaskableOp());
930  if (!maskedOp)
931  return failure();
932  return matchAndRewriteMaskableOp(maskOp, maskedOp, rewriter);
933  }
934 
935 protected:
936  virtual LogicalResult
937  matchAndRewriteMaskableOp(vector::MaskOp maskOp,
938  vector::MaskableOpInterface maskableOp,
939  ConversionPatternRewriter &rewriter) const = 0;
940 };
941 
942 class MaskedReductionOpConversion
943  : public VectorMaskOpConversionBase<vector::ReductionOp> {
944 
945 public:
946  using VectorMaskOpConversionBase<
947  vector::ReductionOp>::VectorMaskOpConversionBase;
948 
949  LogicalResult matchAndRewriteMaskableOp(
950  vector::MaskOp maskOp, MaskableOpInterface maskableOp,
951  ConversionPatternRewriter &rewriter) const override {
952  auto reductionOp = cast<ReductionOp>(maskableOp.getOperation());
953  auto kind = reductionOp.getKind();
954  Type eltType = reductionOp.getDest().getType();
955  Type llvmType = typeConverter->convertType(eltType);
956  Value operand = reductionOp.getVector();
957  Value acc = reductionOp.getAcc();
958  Location loc = reductionOp.getLoc();
959 
960  arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
961  LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
962  reductionOp.getContext(),
963  convertArithFastMathFlagsToLLVM(fMFAttr.getValue()));
964 
965  Value result;
966  switch (kind) {
967  case vector::CombiningKind::ADD:
968  result = lowerPredicatedReductionWithStartValue<
969  LLVM::VPReduceAddOp, ReductionNeutralZero, LLVM::VPReduceFAddOp,
970  ReductionNeutralZero>(rewriter, loc, llvmType, operand, acc,
971  maskOp.getMask());
972  break;
973  case vector::CombiningKind::MUL:
974  result = lowerPredicatedReductionWithStartValue<
975  LLVM::VPReduceMulOp, ReductionNeutralIntOne, LLVM::VPReduceFMulOp,
976  ReductionNeutralFPOne>(rewriter, loc, llvmType, operand, acc,
977  maskOp.getMask());
978  break;
980  result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMinOp,
981  ReductionNeutralUIntMax>(
982  rewriter, loc, llvmType, operand, acc, maskOp.getMask());
983  break;
984  case vector::CombiningKind::MINSI:
985  result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMinOp,
986  ReductionNeutralSIntMax>(
987  rewriter, loc, llvmType, operand, acc, maskOp.getMask());
988  break;
989  case vector::CombiningKind::MAXUI:
990  result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMaxOp,
991  ReductionNeutralUIntMin>(
992  rewriter, loc, llvmType, operand, acc, maskOp.getMask());
993  break;
994  case vector::CombiningKind::MAXSI:
995  result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMaxOp,
996  ReductionNeutralSIntMin>(
997  rewriter, loc, llvmType, operand, acc, maskOp.getMask());
998  break;
999  case vector::CombiningKind::AND:
1000  result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceAndOp,
1001  ReductionNeutralAllOnes>(
1002  rewriter, loc, llvmType, operand, acc, maskOp.getMask());
1003  break;
1004  case vector::CombiningKind::OR:
1005  result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceOrOp,
1006  ReductionNeutralZero>(
1007  rewriter, loc, llvmType, operand, acc, maskOp.getMask());
1008  break;
1009  case vector::CombiningKind::XOR:
1010  result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceXorOp,
1011  ReductionNeutralZero>(
1012  rewriter, loc, llvmType, operand, acc, maskOp.getMask());
1013  break;
1014  case vector::CombiningKind::MINNUMF:
1015  result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMinOp,
1016  ReductionNeutralFPMax>(
1017  rewriter, loc, llvmType, operand, acc, maskOp.getMask());
1018  break;
1019  case vector::CombiningKind::MAXNUMF:
1020  result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMaxOp,
1021  ReductionNeutralFPMin>(
1022  rewriter, loc, llvmType, operand, acc, maskOp.getMask());
1023  break;
1024  case CombiningKind::MAXIMUMF:
1025  result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fmaximum,
1026  MaskNeutralFMaximum>(
1027  rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
1028  break;
1029  case CombiningKind::MINIMUMF:
1030  result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fminimum,
1031  MaskNeutralFMinimum>(
1032  rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
1033  break;
1034  }
1035 
1036  // Replace `vector.mask` operation altogether.
1037  rewriter.replaceOp(maskOp, result);
1038  return success();
1039  }
1040 };
1041 
1042 class VectorShuffleOpConversion
1043  : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
1044 public:
1046 
1047  LogicalResult
1048  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
1049  ConversionPatternRewriter &rewriter) const override {
1050  auto loc = shuffleOp->getLoc();
1051  auto v1Type = shuffleOp.getV1VectorType();
1052  auto v2Type = shuffleOp.getV2VectorType();
1053  auto vectorType = shuffleOp.getResultVectorType();
1054  Type llvmType = typeConverter->convertType(vectorType);
1055  ArrayRef<int64_t> mask = shuffleOp.getMask();
1056 
1057  // Bail if result type cannot be lowered.
1058  if (!llvmType)
1059  return failure();
1060 
1061  // Get rank and dimension sizes.
1062  int64_t rank = vectorType.getRank();
1063 #ifndef NDEBUG
1064  bool wellFormed0DCase =
1065  v1Type.getRank() == 0 && v2Type.getRank() == 0 && rank == 1;
1066  bool wellFormedNDCase =
1067  v1Type.getRank() == rank && v2Type.getRank() == rank;
1068  assert((wellFormed0DCase || wellFormedNDCase) && "op is not well-formed");
1069 #endif
1070 
1071  // For rank 0 and 1, where both operands have *exactly* the same vector
1072  // type, there is direct shuffle support in LLVM. Use it!
1073  if (rank <= 1 && v1Type == v2Type) {
1074  Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
1075  loc, adaptor.getV1(), adaptor.getV2(),
1076  llvm::to_vector_of<int32_t>(mask));
1077  rewriter.replaceOp(shuffleOp, llvmShuffleOp);
1078  return success();
1079  }
1080 
1081  // For all other cases, insert the individual values individually.
1082  int64_t v1Dim = v1Type.getDimSize(0);
1083  Type eltType;
1084  if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(llvmType))
1085  eltType = arrayType.getElementType();
1086  else
1087  eltType = cast<VectorType>(llvmType).getElementType();
1088  Value insert = rewriter.create<LLVM::PoisonOp>(loc, llvmType);
1089  int64_t insPos = 0;
1090  for (int64_t extPos : mask) {
1091  Value value = adaptor.getV1();
1092  if (extPos >= v1Dim) {
1093  extPos -= v1Dim;
1094  value = adaptor.getV2();
1095  }
1096  Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
1097  eltType, rank, extPos);
1098  insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
1099  llvmType, rank, insPos++);
1100  }
1101  rewriter.replaceOp(shuffleOp, insert);
1102  return success();
1103  }
1104 };
1105 
1106 class VectorExtractElementOpConversion
1107  : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
1108 public:
1109  using ConvertOpToLLVMPattern<
1110  vector::ExtractElementOp>::ConvertOpToLLVMPattern;
1111 
1112  LogicalResult
1113  matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
1114  ConversionPatternRewriter &rewriter) const override {
1115  auto vectorType = extractEltOp.getSourceVectorType();
1116  auto llvmType = typeConverter->convertType(vectorType.getElementType());
1117 
1118  // Bail if result type cannot be lowered.
1119  if (!llvmType)
1120  return failure();
1121 
1122  if (vectorType.getRank() == 0) {
1123  Location loc = extractEltOp.getLoc();
1124  auto idxType = rewriter.getIndexType();
1125  auto zero = rewriter.create<LLVM::ConstantOp>(
1126  loc, typeConverter->convertType(idxType),
1127  rewriter.getIntegerAttr(idxType, 0));
1128  rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
1129  extractEltOp, llvmType, adaptor.getVector(), zero);
1130  return success();
1131  }
1132 
1133  rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
1134  extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition());
1135  return success();
1136  }
1137 };
1138 
1139 class VectorExtractOpConversion
1140  : public ConvertOpToLLVMPattern<vector::ExtractOp> {
1141 public:
1143 
1144  LogicalResult
1145  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
1146  ConversionPatternRewriter &rewriter) const override {
1147  auto loc = extractOp->getLoc();
1148  auto resultType = extractOp.getResult().getType();
1149  auto llvmResultType = typeConverter->convertType(resultType);
1150  // Bail if result type cannot be lowered.
1151  if (!llvmResultType)
1152  return failure();
1153 
1155  adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
1156 
1157  // The Vector -> LLVM lowering models N-D vectors as nested aggregates of
1158  // 1-d vectors. This nesting is modeled using arrays. We do this conversion
1159  // from a N-d vector extract to a nested aggregate vector extract in two
1160  // steps:
1161  // - Extract a member from the nested aggregate. The result can be
1162  // a lower rank nested aggregate or a vector (1-D). This is done using
1163  // `llvm.extractvalue`.
1164  // - Extract a scalar out of the vector if needed. This is done using
1165  // `llvm.extractelement`.
1166 
1167  // Determine if we need to extract a member out of the aggregate. We
1168  // always need to extract a member if the input rank >= 2.
1169  bool extractsAggregate = extractOp.getSourceVectorType().getRank() >= 2;
1170  // Determine if we need to extract a scalar as the result. We extract
1171  // a scalar if the extract is full rank, i.e., the number of indices is
1172  // equal to source vector rank.
1173  bool extractsScalar = static_cast<int64_t>(positionVec.size()) ==
1174  extractOp.getSourceVectorType().getRank();
1175 
1176  // Since the LLVM type converter converts 0-d vectors to 1-d vectors, we
1177  // need to add a position for this change.
1178  if (extractOp.getSourceVectorType().getRank() == 0) {
1179  Type idxType = typeConverter->convertType(rewriter.getIndexType());
1180  positionVec.push_back(rewriter.getZeroAttr(idxType));
1181  }
1182 
1183  Value extracted = adaptor.getVector();
1184  if (extractsAggregate) {
1185  ArrayRef<OpFoldResult> position(positionVec);
1186  if (extractsScalar) {
1187  // If we are extracting a scalar from the extracted member, we drop
1188  // the last index, which will be used to extract the scalar out of the
1189  // vector.
1190  position = position.drop_back();
1191  }
1192  // llvm.extractvalue does not support dynamic dimensions.
1193  if (!llvm::all_of(position, llvm::IsaPred<Attribute>)) {
1194  return failure();
1195  }
1196  extracted = rewriter.create<LLVM::ExtractValueOp>(
1197  loc, extracted, getAsIntegers(position));
1198  }
1199 
1200  if (extractsScalar) {
1201  extracted = rewriter.create<LLVM::ExtractElementOp>(
1202  loc, extracted, getAsLLVMValue(rewriter, loc, positionVec.back()));
1203  }
1204 
1205  rewriter.replaceOp(extractOp, extracted);
1206  return success();
1207  }
1208 };
1209 
1210 /// Conversion pattern that turns a vector.fma on a 1-D vector
1211 /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
1212 /// This does not match vectors of n >= 2 rank.
1213 ///
1214 /// Example:
1215 /// ```
1216 /// vector.fma %a, %a, %a : vector<8xf32>
1217 /// ```
1218 /// is converted to:
1219 /// ```
1220 /// llvm.intr.fmuladd %va, %va, %va:
1221 /// (!llvm."<8 x f32>">, !llvm<"<8 x f32>">, !llvm<"<8 x f32>">)
1222 /// -> !llvm."<8 x f32>">
1223 /// ```
1224 class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
1225 public:
1227 
1228  LogicalResult
1229  matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
1230  ConversionPatternRewriter &rewriter) const override {
1231  VectorType vType = fmaOp.getVectorType();
1232  if (vType.getRank() > 1)
1233  return failure();
1234 
1235  rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(
1236  fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
1237  return success();
1238  }
1239 };
1240 
1241 class VectorInsertElementOpConversion
1242  : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
1243 public:
1245 
1246  LogicalResult
1247  matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor,
1248  ConversionPatternRewriter &rewriter) const override {
1249  auto vectorType = insertEltOp.getDestVectorType();
1250  auto llvmType = typeConverter->convertType(vectorType);
1251 
1252  // Bail if result type cannot be lowered.
1253  if (!llvmType)
1254  return failure();
1255 
1256  if (vectorType.getRank() == 0) {
1257  Location loc = insertEltOp.getLoc();
1258  auto idxType = rewriter.getIndexType();
1259  auto zero = rewriter.create<LLVM::ConstantOp>(
1260  loc, typeConverter->convertType(idxType),
1261  rewriter.getIntegerAttr(idxType, 0));
1262  rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
1263  insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero);
1264  return success();
1265  }
1266 
1267  rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
1268  insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(),
1269  adaptor.getPosition());
1270  return success();
1271  }
1272 };
1273 
1274 class VectorInsertOpConversion
1275  : public ConvertOpToLLVMPattern<vector::InsertOp> {
1276 public:
1278 
1279  LogicalResult
1280  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
1281  ConversionPatternRewriter &rewriter) const override {
1282  auto loc = insertOp->getLoc();
1283  auto destVectorType = insertOp.getDestVectorType();
1284  auto llvmResultType = typeConverter->convertType(destVectorType);
1285  // Bail if result type cannot be lowered.
1286  if (!llvmResultType)
1287  return failure();
1288 
1290  adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
1291 
1292  // The logic in this pattern mirrors VectorExtractOpConversion. Refer to
1293  // its explanatory comment about how N-D vectors are converted as nested
1294  // aggregates (llvm.array's) of 1D vectors.
1295  //
1296  // The innermost dimension of the destination vector, when converted to a
1297  // nested aggregate form, will always be a 1D vector.
1298  //
1299  // * If the insertion is happening into the innermost dimension of the
1300  // destination vector:
1301  // - If the destination is a nested aggregate, extract a 1D vector out of
1302  // the aggregate. This can be done using llvm.extractvalue. The
1303  // destination is now guaranteed to be a 1D vector, to which we are
1304  // inserting.
1305  // - Do the insertion into the 1D destination vector, and make the result
1306  // the new source nested aggregate. This can be done using
1307  // llvm.insertelement.
1308  // * Insert the source nested aggregate into the destination nested
1309  // aggregate.
1310 
1311  // Determine if we need to extract/insert a 1D vector out of the aggregate.
1312  bool isNestedAggregate = isa<LLVM::LLVMArrayType>(llvmResultType);
1313  // Determine if we need to insert a scalar into the 1D vector.
1314  bool insertIntoInnermostDim =
1315  static_cast<int64_t>(positionVec.size()) == destVectorType.getRank();
1316 
1317  ArrayRef<OpFoldResult> positionOf1DVectorWithinAggregate(
1318  positionVec.begin(),
1319  insertIntoInnermostDim ? positionVec.size() - 1 : positionVec.size());
1320  OpFoldResult positionOfScalarWithin1DVector;
1321  if (destVectorType.getRank() == 0) {
1322  // Since the LLVM type converter converts 0D vectors to 1D vectors, we
1323  // need to create a 0 here as the position into the 1D vector.
1324  Type idxType = typeConverter->convertType(rewriter.getIndexType());
1325  positionOfScalarWithin1DVector = rewriter.getZeroAttr(idxType);
1326  } else if (insertIntoInnermostDim) {
1327  positionOfScalarWithin1DVector = positionVec.back();
1328  }
1329 
1330  // We are going to mutate this 1D vector until it is either the final
1331  // result (in the non-aggregate case) or the value that needs to be
1332  // inserted into the aggregate result.
1333  Value sourceAggregate = adaptor.getValueToStore();
1334  if (insertIntoInnermostDim) {
1335  // Scalar-into-1D-vector case, so we know we will have to create a
1336  // InsertElementOp. The question is into what destination.
1337  if (isNestedAggregate) {
1338  // Aggregate case: the destination for the InsertElementOp needs to be
1339  // extracted from the aggregate.
1340  if (!llvm::all_of(positionOf1DVectorWithinAggregate,
1341  llvm::IsaPred<Attribute>)) {
1342  // llvm.extractvalue does not support dynamic dimensions.
1343  return failure();
1344  }
1345  sourceAggregate = rewriter.create<LLVM::ExtractValueOp>(
1346  loc, adaptor.getDest(),
1347  getAsIntegers(positionOf1DVectorWithinAggregate));
1348  } else {
1349  // No-aggregate case. The destination for the InsertElementOp is just
1350  // the insertOp's destination.
1351  sourceAggregate = adaptor.getDest();
1352  }
1353  // Insert the scalar into the 1D vector.
1354  sourceAggregate = rewriter.create<LLVM::InsertElementOp>(
1355  loc, sourceAggregate.getType(), sourceAggregate,
1356  adaptor.getValueToStore(),
1357  getAsLLVMValue(rewriter, loc, positionOfScalarWithin1DVector));
1358  }
1359 
1360  Value result = sourceAggregate;
1361  if (isNestedAggregate) {
1362  result = rewriter.create<LLVM::InsertValueOp>(
1363  loc, adaptor.getDest(), sourceAggregate,
1364  getAsIntegers(positionOf1DVectorWithinAggregate));
1365  }
1366 
1367  rewriter.replaceOp(insertOp, result);
1368  return success();
1369  }
1370 };
1371 
1372 /// Lower vector.scalable.insert ops to LLVM vector.insert
1373 struct VectorScalableInsertOpLowering
1374  : public ConvertOpToLLVMPattern<vector::ScalableInsertOp> {
1375  using ConvertOpToLLVMPattern<
1376  vector::ScalableInsertOp>::ConvertOpToLLVMPattern;
1377 
1378  LogicalResult
1379  matchAndRewrite(vector::ScalableInsertOp insOp, OpAdaptor adaptor,
1380  ConversionPatternRewriter &rewriter) const override {
1381  rewriter.replaceOpWithNewOp<LLVM::vector_insert>(
1382  insOp, adaptor.getDest(), adaptor.getValueToStore(), adaptor.getPos());
1383  return success();
1384  }
1385 };
1386 
1387 /// Lower vector.scalable.extract ops to LLVM vector.extract
1388 struct VectorScalableExtractOpLowering
1389  : public ConvertOpToLLVMPattern<vector::ScalableExtractOp> {
1390  using ConvertOpToLLVMPattern<
1391  vector::ScalableExtractOp>::ConvertOpToLLVMPattern;
1392 
1393  LogicalResult
1394  matchAndRewrite(vector::ScalableExtractOp extOp, OpAdaptor adaptor,
1395  ConversionPatternRewriter &rewriter) const override {
1396  rewriter.replaceOpWithNewOp<LLVM::vector_extract>(
1397  extOp, typeConverter->convertType(extOp.getResultVectorType()),
1398  adaptor.getSource(), adaptor.getPos());
1399  return success();
1400  }
1401 };
1402 
1403 /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
1404 ///
1405 /// Example:
1406 /// ```
1407 /// %d = vector.fma %a, %b, %c : vector<2x4xf32>
1408 /// ```
1409 /// is rewritten into:
1410 /// ```
1411 /// %r = splat %f0: vector<2x4xf32>
1412 /// %va = vector.extractvalue %a[0] : vector<2x4xf32>
1413 /// %vb = vector.extractvalue %b[0] : vector<2x4xf32>
1414 /// %vc = vector.extractvalue %c[0] : vector<2x4xf32>
1415 /// %vd = vector.fma %va, %vb, %vc : vector<4xf32>
1416 /// %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
1417 /// %va2 = vector.extractvalue %a2[1] : vector<2x4xf32>
1418 /// %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32>
1419 /// %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32>
1420 /// %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
1421 /// %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
1422 /// // %r3 holds the final value.
1423 /// ```
1424 class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
1425 public:
1427 
1428  void initialize() {
1429  // This pattern recursively unpacks one dimension at a time. The recursion
1430  // bounded as the rank is strictly decreasing.
1431  setHasBoundedRewriteRecursion();
1432  }
1433 
1434  LogicalResult matchAndRewrite(FMAOp op,
1435  PatternRewriter &rewriter) const override {
1436  auto vType = op.getVectorType();
1437  if (vType.getRank() < 2)
1438  return failure();
1439 
1440  auto loc = op.getLoc();
1441  auto elemType = vType.getElementType();
1442  Value zero = rewriter.create<arith::ConstantOp>(
1443  loc, elemType, rewriter.getZeroAttr(elemType));
1444  Value desc = rewriter.create<vector::SplatOp>(loc, vType, zero);
1445  for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
1446  Value extrLHS = rewriter.create<ExtractOp>(loc, op.getLhs(), i);
1447  Value extrRHS = rewriter.create<ExtractOp>(loc, op.getRhs(), i);
1448  Value extrACC = rewriter.create<ExtractOp>(loc, op.getAcc(), i);
1449  Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
1450  desc = rewriter.create<InsertOp>(loc, fma, desc, i);
1451  }
1452  rewriter.replaceOp(op, desc);
1453  return success();
1454  }
1455 };
1456 
1457 /// Returns the strides if the memory underlying `memRefType` has a contiguous
1458 /// static layout.
1459 static std::optional<SmallVector<int64_t, 4>>
1460 computeContiguousStrides(MemRefType memRefType) {
1461  int64_t offset;
1462  SmallVector<int64_t, 4> strides;
1463  if (failed(memRefType.getStridesAndOffset(strides, offset)))
1464  return std::nullopt;
1465  if (!strides.empty() && strides.back() != 1)
1466  return std::nullopt;
1467  // If no layout or identity layout, this is contiguous by definition.
1468  if (memRefType.getLayout().isIdentity())
1469  return strides;
1470 
1471  // Otherwise, we must determine contiguity form shapes. This can only ever
1472  // work in static cases because MemRefType is underspecified to represent
1473  // contiguous dynamic shapes in other ways than with just empty/identity
1474  // layout.
1475  auto sizes = memRefType.getShape();
1476  for (int index = 0, e = strides.size() - 1; index < e; ++index) {
1477  if (ShapedType::isDynamic(sizes[index + 1]) ||
1478  ShapedType::isDynamic(strides[index]) ||
1479  ShapedType::isDynamic(strides[index + 1]))
1480  return std::nullopt;
1481  if (strides[index] != strides[index + 1] * sizes[index + 1])
1482  return std::nullopt;
1483  }
1484  return strides;
1485 }
1486 
1487 class VectorTypeCastOpConversion
1488  : public ConvertOpToLLVMPattern<vector::TypeCastOp> {
1489 public:
1491 
1492  LogicalResult
1493  matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor,
1494  ConversionPatternRewriter &rewriter) const override {
1495  auto loc = castOp->getLoc();
1496  MemRefType sourceMemRefType =
1497  cast<MemRefType>(castOp.getOperand().getType());
1498  MemRefType targetMemRefType = castOp.getType();
1499 
1500  // Only static shape casts supported atm.
1501  if (!sourceMemRefType.hasStaticShape() ||
1502  !targetMemRefType.hasStaticShape())
1503  return failure();
1504 
1505  auto llvmSourceDescriptorTy =
1506  dyn_cast<LLVM::LLVMStructType>(adaptor.getOperands()[0].getType());
1507  if (!llvmSourceDescriptorTy)
1508  return failure();
1509  MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]);
1510 
1511  auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
1512  typeConverter->convertType(targetMemRefType));
1513  if (!llvmTargetDescriptorTy)
1514  return failure();
1515 
1516  // Only contiguous source buffers supported atm.
1517  auto sourceStrides = computeContiguousStrides(sourceMemRefType);
1518  if (!sourceStrides)
1519  return failure();
1520  auto targetStrides = computeContiguousStrides(targetMemRefType);
1521  if (!targetStrides)
1522  return failure();
1523  // Only support static strides for now, regardless of contiguity.
1524  if (llvm::any_of(*targetStrides, ShapedType::isDynamic))
1525  return failure();
1526 
1527  auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
1528 
1529  // Create descriptor.
1530  auto desc = MemRefDescriptor::poison(rewriter, loc, llvmTargetDescriptorTy);
1531  // Set allocated ptr.
1532  Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
1533  desc.setAllocatedPtr(rewriter, loc, allocated);
1534 
1535  // Set aligned ptr.
1536  Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
1537  desc.setAlignedPtr(rewriter, loc, ptr);
1538  // Fill offset 0.
1539  auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
1540  auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
1541  desc.setOffset(rewriter, loc, zero);
1542 
1543  // Fill size and stride descriptors in memref.
1544  for (const auto &indexedSize :
1545  llvm::enumerate(targetMemRefType.getShape())) {
1546  int64_t index = indexedSize.index();
1547  auto sizeAttr =
1548  rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
1549  auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
1550  desc.setSize(rewriter, loc, index, size);
1551  auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
1552  (*targetStrides)[index]);
1553  auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
1554  desc.setStride(rewriter, loc, index, stride);
1555  }
1556 
1557  rewriter.replaceOp(castOp, {desc});
1558  return success();
1559  }
1560 };
1561 
1562 /// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only).
1563 /// Non-scalable versions of this operation are handled in Vector Transforms.
1564 class VectorCreateMaskOpConversion
1565  : public OpConversionPattern<vector::CreateMaskOp> {
1566 public:
1567  explicit VectorCreateMaskOpConversion(MLIRContext *context,
1568  bool enableIndexOpt)
1569  : OpConversionPattern<vector::CreateMaskOp>(context),
1570  force32BitVectorIndices(enableIndexOpt) {}
1571 
1572  LogicalResult
1573  matchAndRewrite(vector::CreateMaskOp op, OpAdaptor adaptor,
1574  ConversionPatternRewriter &rewriter) const override {
1575  auto dstType = op.getType();
1576  if (dstType.getRank() != 1 || !cast<VectorType>(dstType).isScalable())
1577  return failure();
1578  IntegerType idxType =
1579  force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
1580  auto loc = op->getLoc();
1581  Value indices = rewriter.create<LLVM::StepVectorOp>(
1582  loc, LLVM::getVectorType(idxType, dstType.getShape()[0],
1583  /*isScalable=*/true));
1584  auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType,
1585  adaptor.getOperands()[0]);
1586  Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
1587  Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
1588  indices, bounds);
1589  rewriter.replaceOp(op, comp);
1590  return success();
1591  }
1592 
1593 private:
1594  const bool force32BitVectorIndices;
1595 };
1596 
1597 class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
1598 public:
1600 
1601  // Lowering implementation that relies on a small runtime support library,
1602  // which only needs to provide a few printing methods (single value for all
1603  // data types, opening/closing bracket, comma, newline). The lowering splits
1604  // the vector into elementary printing operations. The advantage of this
1605  // approach is that the library can remain unaware of all low-level
1606  // implementation details of vectors while still supporting output of any
1607  // shaped and dimensioned vector.
1608  //
1609  // Note: This lowering only handles scalars, n-D vectors are broken into
1610  // printing scalars in loops in VectorToSCF.
1611  //
1612  // TODO: rely solely on libc in future? something else?
1613  //
1614  LogicalResult
1615  matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor,
1616  ConversionPatternRewriter &rewriter) const override {
1617  auto parent = printOp->getParentOfType<ModuleOp>();
1618  if (!parent)
1619  return failure();
1620 
1621  auto loc = printOp->getLoc();
1622 
1623  if (auto value = adaptor.getSource()) {
1624  Type printType = printOp.getPrintType();
1625  if (isa<VectorType>(printType)) {
1626  // Vectors should be broken into elementary print ops in VectorToSCF.
1627  return failure();
1628  }
1629  if (failed(emitScalarPrint(rewriter, parent, loc, printType, value)))
1630  return failure();
1631  }
1632 
1633  auto punct = printOp.getPunctuation();
1634  if (auto stringLiteral = printOp.getStringLiteral()) {
1635  auto createResult =
1636  LLVM::createPrintStrCall(rewriter, loc, parent, "vector_print_str",
1637  *stringLiteral, *getTypeConverter(),
1638  /*addNewline=*/false);
1639  if (createResult.failed())
1640  return failure();
1641 
1642  } else if (punct != PrintPunctuation::NoPunctuation) {
1643  FailureOr<LLVM::LLVMFuncOp> op = [&]() {
1644  switch (punct) {
1645  case PrintPunctuation::Close:
1646  return LLVM::lookupOrCreatePrintCloseFn(rewriter, parent);
1647  case PrintPunctuation::Open:
1648  return LLVM::lookupOrCreatePrintOpenFn(rewriter, parent);
1649  case PrintPunctuation::Comma:
1650  return LLVM::lookupOrCreatePrintCommaFn(rewriter, parent);
1651  case PrintPunctuation::NewLine:
1652  return LLVM::lookupOrCreatePrintNewlineFn(rewriter, parent);
1653  default:
1654  llvm_unreachable("unexpected punctuation");
1655  }
1656  }();
1657  if (failed(op))
1658  return failure();
1659  emitCall(rewriter, printOp->getLoc(), op.value());
1660  }
1661 
1662  rewriter.eraseOp(printOp);
1663  return success();
1664  }
1665 
1666 private:
1667  enum class PrintConversion {
1668  // clang-format off
1669  None,
1670  ZeroExt64,
1671  SignExt64,
1672  Bitcast16
1673  // clang-format on
1674  };
1675 
1676  LogicalResult emitScalarPrint(ConversionPatternRewriter &rewriter,
1677  ModuleOp parent, Location loc, Type printType,
1678  Value value) const {
1679  if (typeConverter->convertType(printType) == nullptr)
1680  return failure();
1681 
1682  // Make sure element type has runtime support.
1683  PrintConversion conversion = PrintConversion::None;
1684  FailureOr<Operation *> printer;
1685  if (printType.isF32()) {
1686  printer = LLVM::lookupOrCreatePrintF32Fn(rewriter, parent);
1687  } else if (printType.isF64()) {
1688  printer = LLVM::lookupOrCreatePrintF64Fn(rewriter, parent);
1689  } else if (printType.isF16()) {
1690  conversion = PrintConversion::Bitcast16; // bits!
1691  printer = LLVM::lookupOrCreatePrintF16Fn(rewriter, parent);
1692  } else if (printType.isBF16()) {
1693  conversion = PrintConversion::Bitcast16; // bits!
1694  printer = LLVM::lookupOrCreatePrintBF16Fn(rewriter, parent);
1695  } else if (printType.isIndex()) {
1696  printer = LLVM::lookupOrCreatePrintU64Fn(rewriter, parent);
1697  } else if (auto intTy = dyn_cast<IntegerType>(printType)) {
1698  // Integers need a zero or sign extension on the operand
1699  // (depending on the source type) as well as a signed or
1700  // unsigned print method. Up to 64-bit is supported.
1701  unsigned width = intTy.getWidth();
1702  if (intTy.isUnsigned()) {
1703  if (width <= 64) {
1704  if (width < 64)
1705  conversion = PrintConversion::ZeroExt64;
1706  printer = LLVM::lookupOrCreatePrintU64Fn(rewriter, parent);
1707  } else {
1708  return failure();
1709  }
1710  } else {
1711  assert(intTy.isSignless() || intTy.isSigned());
1712  if (width <= 64) {
1713  // Note that we *always* zero extend booleans (1-bit integers),
1714  // so that true/false is printed as 1/0 rather than -1/0.
1715  if (width == 1)
1716  conversion = PrintConversion::ZeroExt64;
1717  else if (width < 64)
1718  conversion = PrintConversion::SignExt64;
1719  printer = LLVM::lookupOrCreatePrintI64Fn(rewriter, parent);
1720  } else {
1721  return failure();
1722  }
1723  }
1724  } else {
1725  return failure();
1726  }
1727  if (failed(printer))
1728  return failure();
1729 
1730  switch (conversion) {
1731  case PrintConversion::ZeroExt64:
1732  value = rewriter.create<arith::ExtUIOp>(
1733  loc, IntegerType::get(rewriter.getContext(), 64), value);
1734  break;
1735  case PrintConversion::SignExt64:
1736  value = rewriter.create<arith::ExtSIOp>(
1737  loc, IntegerType::get(rewriter.getContext(), 64), value);
1738  break;
1739  case PrintConversion::Bitcast16:
1740  value = rewriter.create<LLVM::BitcastOp>(
1741  loc, IntegerType::get(rewriter.getContext(), 16), value);
1742  break;
1743  case PrintConversion::None:
1744  break;
1745  }
1746  emitCall(rewriter, loc, printer.value(), value);
1747  return success();
1748  }
1749 
1750  // Helper to emit a call.
1751  static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
1752  Operation *ref, ValueRange params = ValueRange()) {
1753  rewriter.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(ref),
1754  params);
1755  }
1756 };
1757 
1758 /// The Splat operation is lowered to an insertelement + a shufflevector
1759 /// operation. Splat to only 0-d and 1-d vector result types are lowered.
1760 struct VectorSplatOpLowering : public ConvertOpToLLVMPattern<vector::SplatOp> {
1762 
1763  LogicalResult
1764  matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
1765  ConversionPatternRewriter &rewriter) const override {
1766  VectorType resultType = cast<VectorType>(splatOp.getType());
1767  if (resultType.getRank() > 1)
1768  return failure();
1769 
1770  // First insert it into a poison vector so we can shuffle it.
1771  auto vectorType = typeConverter->convertType(splatOp.getType());
1772  Value poison =
1773  rewriter.create<LLVM::PoisonOp>(splatOp.getLoc(), vectorType);
1774  auto zero = rewriter.create<LLVM::ConstantOp>(
1775  splatOp.getLoc(),
1776  typeConverter->convertType(rewriter.getIntegerType(32)),
1777  rewriter.getZeroAttr(rewriter.getIntegerType(32)));
1778 
1779  // For 0-d vector, we simply do `insertelement`.
1780  if (resultType.getRank() == 0) {
1781  rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
1782  splatOp, vectorType, poison, adaptor.getInput(), zero);
1783  return success();
1784  }
1785 
1786  // For 1-d vector, we additionally do a `vectorshuffle`.
1787  auto v = rewriter.create<LLVM::InsertElementOp>(
1788  splatOp.getLoc(), vectorType, poison, adaptor.getInput(), zero);
1789 
1790  int64_t width = cast<VectorType>(splatOp.getType()).getDimSize(0);
1791  SmallVector<int32_t> zeroValues(width, 0);
1792 
1793  // Shuffle the value across the desired number of elements.
1794  rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, poison,
1795  zeroValues);
1796  return success();
1797  }
1798 };
1799 
1800 /// The Splat operation is lowered to an insertelement + a shufflevector
1801 /// operation. Splat to only 2+-d vector result types are lowered by the
1802 /// SplatNdOpLowering, the 1-d case is handled by SplatOpLowering.
1803 struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
1805 
1806  LogicalResult
1807  matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
1808  ConversionPatternRewriter &rewriter) const override {
1809  VectorType resultType = splatOp.getType();
1810  if (resultType.getRank() <= 1)
1811  return failure();
1812 
1813  // First insert it into an undef vector so we can shuffle it.
1814  auto loc = splatOp.getLoc();
1815  auto vectorTypeInfo =
1816  LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter());
1817  auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
1818  auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
1819  if (!llvmNDVectorTy || !llvm1DVectorTy)
1820  return failure();
1821 
1822  // Construct returned value.
1823  Value desc = rewriter.create<LLVM::PoisonOp>(loc, llvmNDVectorTy);
1824 
1825  // Construct a 1-D vector with the splatted value that we insert in all the
1826  // places within the returned descriptor.
1827  Value vdesc = rewriter.create<LLVM::PoisonOp>(loc, llvm1DVectorTy);
1828  auto zero = rewriter.create<LLVM::ConstantOp>(
1829  loc, typeConverter->convertType(rewriter.getIntegerType(32)),
1830  rewriter.getZeroAttr(rewriter.getIntegerType(32)));
1831  Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
1832  adaptor.getInput(), zero);
1833 
1834  // Shuffle the value across the desired number of elements.
1835  int64_t width = resultType.getDimSize(resultType.getRank() - 1);
1836  SmallVector<int32_t> zeroValues(width, 0);
1837  v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroValues);
1838 
1839  // Iterate of linear index, convert to coords space and insert splatted 1-D
1840  // vector in each position.
1841  nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
1842  desc = rewriter.create<LLVM::InsertValueOp>(loc, desc, v, position);
1843  });
1844  rewriter.replaceOp(splatOp, desc);
1845  return success();
1846  }
1847 };
1848 
1849 /// Conversion pattern for a `vector.interleave`.
1850 /// This supports fixed-sized vectors and scalable vectors.
1851 struct VectorInterleaveOpLowering
1852  : public ConvertOpToLLVMPattern<vector::InterleaveOp> {
1854 
1855  LogicalResult
1856  matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
1857  ConversionPatternRewriter &rewriter) const override {
1858  VectorType resultType = interleaveOp.getResultVectorType();
1859  // n-D interleaves should have been lowered already.
1860  if (resultType.getRank() != 1)
1861  return rewriter.notifyMatchFailure(interleaveOp,
1862  "InterleaveOp not rank 1");
1863  // If the result is rank 1, then this directly maps to LLVM.
1864  if (resultType.isScalable()) {
1865  rewriter.replaceOpWithNewOp<LLVM::vector_interleave2>(
1866  interleaveOp, typeConverter->convertType(resultType),
1867  adaptor.getLhs(), adaptor.getRhs());
1868  return success();
1869  }
1870  // Lower fixed-size interleaves to a shufflevector. While the
1871  // vector.interleave2 intrinsic supports fixed and scalable vectors, the
1872  // langref still recommends fixed-vectors use shufflevector, see:
1873  // https://llvm.org/docs/LangRef.html#id876.
1874  int64_t resultVectorSize = resultType.getNumElements();
1875  SmallVector<int32_t> interleaveShuffleMask;
1876  interleaveShuffleMask.reserve(resultVectorSize);
1877  for (int i = 0, end = resultVectorSize / 2; i < end; ++i) {
1878  interleaveShuffleMask.push_back(i);
1879  interleaveShuffleMask.push_back((resultVectorSize / 2) + i);
1880  }
1881  rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(
1882  interleaveOp, adaptor.getLhs(), adaptor.getRhs(),
1883  interleaveShuffleMask);
1884  return success();
1885  }
1886 };
1887 
1888 /// Conversion pattern for a `vector.deinterleave`.
1889 /// This supports fixed-sized vectors and scalable vectors.
1890 struct VectorDeinterleaveOpLowering
1891  : public ConvertOpToLLVMPattern<vector::DeinterleaveOp> {
1893 
1894  LogicalResult
1895  matchAndRewrite(vector::DeinterleaveOp deinterleaveOp, OpAdaptor adaptor,
1896  ConversionPatternRewriter &rewriter) const override {
1897  VectorType resultType = deinterleaveOp.getResultVectorType();
1898  VectorType sourceType = deinterleaveOp.getSourceVectorType();
1899  auto loc = deinterleaveOp.getLoc();
1900 
1901  // Note: n-D deinterleave operations should be lowered to the 1-D before
1902  // converting to LLVM.
1903  if (resultType.getRank() != 1)
1904  return rewriter.notifyMatchFailure(deinterleaveOp,
1905  "DeinterleaveOp not rank 1");
1906 
1907  if (resultType.isScalable()) {
1908  auto llvmTypeConverter = this->getTypeConverter();
1909  auto deinterleaveResults = deinterleaveOp.getResultTypes();
1910  auto packedOpResults =
1911  llvmTypeConverter->packOperationResults(deinterleaveResults);
1912  auto intrinsic = rewriter.create<LLVM::vector_deinterleave2>(
1913  loc, packedOpResults, adaptor.getSource());
1914 
1915  auto evenResult = rewriter.create<LLVM::ExtractValueOp>(
1916  loc, intrinsic->getResult(0), 0);
1917  auto oddResult = rewriter.create<LLVM::ExtractValueOp>(
1918  loc, intrinsic->getResult(0), 1);
1919 
1920  rewriter.replaceOp(deinterleaveOp, ValueRange{evenResult, oddResult});
1921  return success();
1922  }
1923  // Lower fixed-size deinterleave to two shufflevectors. While the
1924  // vector.deinterleave2 intrinsic supports fixed and scalable vectors, the
1925  // langref still recommends fixed-vectors use shufflevector, see:
1926  // https://llvm.org/docs/LangRef.html#id889.
1927  int64_t resultVectorSize = resultType.getNumElements();
1928  SmallVector<int32_t> evenShuffleMask;
1929  SmallVector<int32_t> oddShuffleMask;
1930 
1931  evenShuffleMask.reserve(resultVectorSize);
1932  oddShuffleMask.reserve(resultVectorSize);
1933 
1934  for (int i = 0; i < sourceType.getNumElements(); ++i) {
1935  if (i % 2 == 0)
1936  evenShuffleMask.push_back(i);
1937  else
1938  oddShuffleMask.push_back(i);
1939  }
1940 
1941  auto poison = rewriter.create<LLVM::PoisonOp>(loc, sourceType);
1942  auto evenShuffle = rewriter.create<LLVM::ShuffleVectorOp>(
1943  loc, adaptor.getSource(), poison, evenShuffleMask);
1944  auto oddShuffle = rewriter.create<LLVM::ShuffleVectorOp>(
1945  loc, adaptor.getSource(), poison, oddShuffleMask);
1946 
1947  rewriter.replaceOp(deinterleaveOp, ValueRange{evenShuffle, oddShuffle});
1948  return success();
1949  }
1950 };
1951 
1952 /// Conversion pattern for a `vector.from_elements`.
1953 struct VectorFromElementsLowering
1954  : public ConvertOpToLLVMPattern<vector::FromElementsOp> {
1956 
1957  LogicalResult
1958  matchAndRewrite(vector::FromElementsOp fromElementsOp, OpAdaptor adaptor,
1959  ConversionPatternRewriter &rewriter) const override {
1960  Location loc = fromElementsOp.getLoc();
1961  VectorType vectorType = fromElementsOp.getType();
1962  // TODO: Multi-dimensional vectors lower to !llvm.array<... x vector<>>.
1963  // Such ops should be handled in the same way as vector.insert.
1964  if (vectorType.getRank() > 1)
1965  return rewriter.notifyMatchFailure(fromElementsOp,
1966  "rank > 1 vectors are not supported");
1967  Type llvmType = typeConverter->convertType(vectorType);
1968  Value result = rewriter.create<LLVM::PoisonOp>(loc, llvmType);
1969  for (auto [idx, val] : llvm::enumerate(adaptor.getElements()))
1970  result = rewriter.create<vector::InsertOp>(loc, val, result, idx);
1971  rewriter.replaceOp(fromElementsOp, result);
1972  return success();
1973  }
1974 };
1975 
1976 /// Conversion pattern for vector.step.
1977 struct VectorScalableStepOpLowering
1978  : public ConvertOpToLLVMPattern<vector::StepOp> {
1980 
1981  LogicalResult
1982  matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor,
1983  ConversionPatternRewriter &rewriter) const override {
1984  auto resultType = cast<VectorType>(stepOp.getType());
1985  if (!resultType.isScalable()) {
1986  return failure();
1987  }
1988  Type llvmType = typeConverter->convertType(stepOp.getType());
1989  rewriter.replaceOpWithNewOp<LLVM::StepVectorOp>(stepOp, llvmType);
1990  return success();
1991  }
1992 };
1993 
1994 } // namespace
1995 
1998  patterns.add<VectorFMAOpNDRewritePattern>(patterns.getContext());
1999 }
2000 
2001 /// Populate the given list with patterns that convert from Vector to LLVM.
2003  const LLVMTypeConverter &converter, RewritePatternSet &patterns,
2004  bool reassociateFPReductions, bool force32BitVectorIndices,
2005  bool useVectorAlignment) {
2006  // This function populates only ConversionPatterns, not RewritePatterns.
2007  MLIRContext *ctx = converter.getDialect()->getContext();
2008  patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
2009  patterns.add<VectorCreateMaskOpConversion>(ctx, force32BitVectorIndices);
2010  patterns.add<VectorLoadStoreConversion<vector::LoadOp>,
2011  VectorLoadStoreConversion<vector::MaskedLoadOp>,
2012  VectorLoadStoreConversion<vector::StoreOp>,
2013  VectorLoadStoreConversion<vector::MaskedStoreOp>,
2014  VectorGatherOpConversion, VectorScatterOpConversion>(
2015  converter, useVectorAlignment);
2016  patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
2017  VectorExtractElementOpConversion, VectorExtractOpConversion,
2018  VectorFMAOp1DConversion, VectorInsertElementOpConversion,
2019  VectorInsertOpConversion, VectorPrintOpConversion,
2020  VectorTypeCastOpConversion, VectorScaleOpConversion,
2021  VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
2022  VectorSplatOpLowering, VectorSplatNdOpLowering,
2023  VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
2024  MaskedReductionOpConversion, VectorInterleaveOpLowering,
2025  VectorDeinterleaveOpLowering, VectorFromElementsLowering,
2026  VectorScalableStepOpLowering>(converter);
2027 }
2028 
2030  const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
2031  patterns.add<VectorMatmulOpConversion>(converter);
2032  patterns.add<VectorFlatTransposeOpConversion>(converter);
2033 }
2034 
2035 namespace {
2036 struct VectorToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
2038  void loadDependentDialects(MLIRContext *context) const final {
2039  context->loadDialect<LLVM::LLVMDialect>();
2040  }
2041 
2042  /// Hook for derived dialect interface to provide conversion patterns
2043  /// and mark dialect legal for the conversion target.
2044  void populateConvertToLLVMConversionPatterns(
2045  ConversionTarget &target, LLVMTypeConverter &typeConverter,
2046  RewritePatternSet &patterns) const final {
2048  }
2049 };
2050 } // namespace
2051 
2053  DialectRegistry &registry) {
2054  registry.addExtension(+[](MLIRContext *ctx, vector::VectorDialect *dialect) {
2055  dialect->addInterfaces<VectorToLLVMDialectInterface>();
2056  });
2057 }
static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc, const LLVMTypeConverter &typeConverter, MemRefType memRefType, Value llvmMemref, Value base, Value index, VectorType vectorType)
LogicalResult getVectorToLLVMAlignment(const LLVMTypeConverter &typeConverter, VectorType vectorType, MemRefType memrefType, unsigned &align, bool useVectorAlignment)
LogicalResult getVectorAlignment(const LLVMTypeConverter &typeConverter, VectorType vectorType, unsigned &align)
LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter, MemRefType memrefType, unsigned &align)
static Value extractOne(ConversionPatternRewriter &rewriter, const LLVMTypeConverter &typeConverter, Location loc, Value val, Type llvmType, int64_t rank, int64_t pos)
static Value insertOne(ConversionPatternRewriter &rewriter, const LLVMTypeConverter &typeConverter, Location loc, Value val1, Value val2, Type llvmType, int64_t rank, int64_t pos)
static Value getAsLLVMValue(OpBuilder &builder, Location loc, OpFoldResult foldResult)
Convert foldResult into a Value.
static LogicalResult isMemRefTypeSupported(MemRefType memRefType, const LLVMTypeConverter &converter)
union mlir::linalg::@1194::ArityGroupAndKind::Kind kind
@ None
#define MINUI(lhs, rhs)
static void printOp(llvm::raw_ostream &os, Operation *op, OpPrintingFlags &flags)
Definition: Unit.cpp:19
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:196
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:224
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:250
IntegerType getI64Type()
Definition: Builders.cpp:65
IntegerType getI32Type()
Definition: Builders.cpp:63
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:67
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:320
MLIRContext * getContext() const
Definition: Builders.h:55
IndexType getIndexType()
Definition: Builders.cpp:51
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:155
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition: Pattern.h:161
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
LLVM::LLVMDialect * getDialect() const
Returns the LLVM dialect.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
FailureOr< unsigned > getMemRefAddressSpace(BaseMemRefType type) const
Return the LLVM address space corresponding to the memory space of the memref type type or failure if...
const llvm::DataLayout & getDataLayout() const
Returns the data layout to use during and after conversion.
Utility class to translate MLIR LLVM dialect types to LLVM IR.
Definition: TypeToLLVM.h:39
unsigned getPreferredAlignment(Type type, const llvm::DataLayout &layout)
Returns the preferred alignment for the type given the data layout.
Definition: TypeToLLVM.cpp:183
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Definition: MemRefBuilder.h:33
LLVM::LLVMPointerType getElementPtrType()
Returns the (LLVM) pointer type this descriptor contains.
static MemRefDescriptor poison(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating a poison value of the descriptor type.
Generic implementation of one-to-one conversion from "SourceOp" to "TargetOp" where the latter belong...
Definition: Pattern.h:207
This class helps build Operations.
Definition: Builders.h:204
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
This class represents a single result from folding an operation.
Definition: OpDefinition.h:271
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:749
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:682
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
Definition: Types.cpp:112
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder, function_ref< void(ArrayRef< int64_t >)> fun)
NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType, const LLVMTypeConverter &converter)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintNewlineFn(OpBuilder &b, Operation *moduleOp)
Type getVectorType(Type elementType, unsigned numElements, bool isScalable=false)
Creates an LLVM dialect-compatible vector type with the given element type and length.
Definition: LLVMTypes.cpp:841
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintOpenFn(OpBuilder &b, Operation *moduleOp)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintCloseFn(OpBuilder &b, Operation *moduleOp)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintF16Fn(OpBuilder &b, Operation *moduleOp)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintBF16Fn(OpBuilder &b, Operation *moduleOp)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintU64Fn(OpBuilder &b, Operation *moduleOp)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintI64Fn(OpBuilder &b, Operation *moduleOp)
Helper functions to look up or create the declaration for commonly used external C function calls.
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintCommaFn(OpBuilder &b, Operation *moduleOp)
LogicalResult createPrintStrCall(OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName, StringRef string, const LLVMTypeConverter &typeConverter, bool addNewline=true, std::optional< StringRef > runtimeFunctionName={})
Generate IR that prints the given string to stdout.
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintF32Fn(OpBuilder &b, Operation *moduleOp)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp)
LLVM::FastmathFlags convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF)
Maps arithmetic fastmath enum values to LLVM enum values.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns)
Populates a pattern that rank-reduces n-D FMAs into (n-1)-D FMAs where n > 1.
void registerConvertVectorToLLVMInterface(DialectRegistry &registry)
SmallVector< int64_t > getAsIntegers(ArrayRef< Value > values)
Returns the integer numbers in values.
Definition: VectorOps.cpp:345
Include the generated interface declarations.
Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, Type targetType, Value value)
Create a cast from an index-like value (index or integer) to another index-like value.
Definition: Utils.cpp:120
const FrozenRewritePatternSet & patterns
void populateVectorToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, bool reassociateFPReductions=false, bool force32BitVectorIndices=false, bool useVectorAlignment=false)
Collect a set of patterns to convert from the Vector dialect to LLVM.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
void populateVectorToLLVMMatrixConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to convert from Vector contractions to LLVM Matrix Intrinsics.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314