MLIR  15.0.0git
ConvertVectorToLLVM.cpp
Go to the documentation of this file.
1 //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
18 #include "mlir/IR/BuiltinTypes.h"
22 
23 using namespace mlir;
24 using namespace mlir::vector;
25 
26 // Helper to reduce vector type by one rank at front.
27 static VectorType reducedVectorTypeFront(VectorType tp) {
28  assert((tp.getRank() > 1) && "unlowerable vector type");
29  unsigned numScalableDims = tp.getNumScalableDims();
30  if (tp.getShape().size() == numScalableDims)
31  --numScalableDims;
32  return VectorType::get(tp.getShape().drop_front(), tp.getElementType(),
33  numScalableDims);
34 }
35 
36 // Helper to reduce vector type by *all* but one rank at back.
37 static VectorType reducedVectorTypeBack(VectorType tp) {
38  assert((tp.getRank() > 1) && "unlowerable vector type");
39  unsigned numScalableDims = tp.getNumScalableDims();
40  if (numScalableDims > 0)
41  --numScalableDims;
42  return VectorType::get(tp.getShape().take_back(), tp.getElementType(),
43  numScalableDims);
44 }
45 
46 // Helper that picks the proper sequence for inserting.
48  LLVMTypeConverter &typeConverter, Location loc,
49  Value val1, Value val2, Type llvmType, int64_t rank,
50  int64_t pos) {
51  assert(rank > 0 && "0-D vector corner case should have been handled already");
52  if (rank == 1) {
53  auto idxType = rewriter.getIndexType();
54  auto constant = rewriter.create<LLVM::ConstantOp>(
55  loc, typeConverter.convertType(idxType),
56  rewriter.getIntegerAttr(idxType, pos));
57  return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
58  constant);
59  }
60  return rewriter.create<LLVM::InsertValueOp>(loc, llvmType, val1, val2,
61  rewriter.getI64ArrayAttr(pos));
62 }
63 
64 // Helper that picks the proper sequence for extracting.
66  LLVMTypeConverter &typeConverter, Location loc,
67  Value val, Type llvmType, int64_t rank, int64_t pos) {
68  if (rank <= 1) {
69  auto idxType = rewriter.getIndexType();
70  auto constant = rewriter.create<LLVM::ConstantOp>(
71  loc, typeConverter.convertType(idxType),
72  rewriter.getIntegerAttr(idxType, pos));
73  return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
74  constant);
75  }
76  return rewriter.create<LLVM::ExtractValueOp>(loc, llvmType, val,
77  rewriter.getI64ArrayAttr(pos));
78 }
79 
80 // Helper that returns data layout alignment of a memref.
82  MemRefType memrefType, unsigned &align) {
83  Type elementTy = typeConverter.convertType(memrefType.getElementType());
84  if (!elementTy)
85  return failure();
86 
87  // TODO: this should use the MLIR data layout when it becomes available and
88  // stop depending on translation.
89  llvm::LLVMContext llvmContext;
90  align = LLVM::TypeToLLVMIRTranslator(llvmContext)
91  .getPreferredAlignment(elementTy, typeConverter.getDataLayout());
92  return success();
93 }
94 
95 // Add an index vector component to a base pointer. This almost always succeeds
96 // unless the last stride is non-unit or the memory space is not zero.
98  Location loc, Value memref, Value base,
99  Value index, MemRefType memRefType,
100  VectorType vType, Value &ptrs) {
101  int64_t offset;
102  SmallVector<int64_t, 4> strides;
103  auto successStrides = getStridesAndOffset(memRefType, strides, offset);
104  if (failed(successStrides) || strides.back() != 1 ||
105  memRefType.getMemorySpaceAsInt() != 0)
106  return failure();
107  auto pType = MemRefDescriptor(memref).getElementPtrType();
108  auto ptrsType = LLVM::getFixedVectorType(pType, vType.getDimSize(0));
109  ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, index);
110  return success();
111 }
112 
113 // Casts a strided element pointer to a vector pointer. The vector pointer
114 // will be in the same address space as the incoming memref type.
116  Value ptr, MemRefType memRefType, Type vt) {
117  auto pType = LLVM::LLVMPointerType::get(vt, memRefType.getMemorySpaceAsInt());
118  return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr);
119 }
120 
121 namespace {
122 
123 /// Trivial Vector to LLVM conversions
124 using VectorScaleOpConversion =
126 
127 /// Conversion pattern for a vector.bitcast.
128 class VectorBitCastOpConversion
129  : public ConvertOpToLLVMPattern<vector::BitCastOp> {
130 public:
132 
134  matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor,
135  ConversionPatternRewriter &rewriter) const override {
136  // Only 0-D and 1-D vectors can be lowered to LLVM.
137  VectorType resultTy = bitCastOp.getResultVectorType();
138  if (resultTy.getRank() > 1)
139  return failure();
140  Type newResultTy = typeConverter->convertType(resultTy);
141  rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy,
142  adaptor.getOperands()[0]);
143  return success();
144  }
145 };
146 
147 /// Conversion pattern for a vector.matrix_multiply.
148 /// This is lowered directly to the proper llvm.intr.matrix.multiply.
149 class VectorMatmulOpConversion
150  : public ConvertOpToLLVMPattern<vector::MatmulOp> {
151 public:
153 
155  matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor,
156  ConversionPatternRewriter &rewriter) const override {
157  rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
158  matmulOp, typeConverter->convertType(matmulOp.getRes().getType()),
159  adaptor.getLhs(), adaptor.getRhs(), matmulOp.getLhsRows(),
160  matmulOp.getLhsColumns(), matmulOp.getRhsColumns());
161  return success();
162  }
163 };
164 
165 /// Conversion pattern for a vector.flat_transpose.
166 /// This is lowered directly to the proper llvm.intr.matrix.transpose.
167 class VectorFlatTransposeOpConversion
168  : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> {
169 public:
171 
173  matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor,
174  ConversionPatternRewriter &rewriter) const override {
175  rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
176  transOp, typeConverter->convertType(transOp.getRes().getType()),
177  adaptor.getMatrix(), transOp.getRows(), transOp.getColumns());
178  return success();
179  }
180 };
181 
182 /// Overloaded utility that replaces a vector.load, vector.store,
183 /// vector.maskedload and vector.maskedstore with their respective LLVM
184 /// couterparts.
185 static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
186  vector::LoadOpAdaptor adaptor,
187  VectorType vectorTy, Value ptr, unsigned align,
188  ConversionPatternRewriter &rewriter) {
189  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, ptr, align);
190 }
191 
192 static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
193  vector::MaskedLoadOpAdaptor adaptor,
194  VectorType vectorTy, Value ptr, unsigned align,
195  ConversionPatternRewriter &rewriter) {
196  rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
197  loadOp, vectorTy, ptr, adaptor.getMask(), adaptor.getPassThru(), align);
198 }
199 
200 static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
201  vector::StoreOpAdaptor adaptor,
202  VectorType vectorTy, Value ptr, unsigned align,
203  ConversionPatternRewriter &rewriter) {
204  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValueToStore(),
205  ptr, align);
206 }
207 
208 static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
209  vector::MaskedStoreOpAdaptor adaptor,
210  VectorType vectorTy, Value ptr, unsigned align,
211  ConversionPatternRewriter &rewriter) {
212  rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
213  storeOp, adaptor.getValueToStore(), ptr, adaptor.getMask(), align);
214 }
215 
216 /// Conversion pattern for a vector.load, vector.store, vector.maskedload, and
217 /// vector.maskedstore.
218 template <class LoadOrStoreOp, class LoadOrStoreOpAdaptor>
219 class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
220 public:
222 
224  matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
225  typename LoadOrStoreOp::Adaptor adaptor,
226  ConversionPatternRewriter &rewriter) const override {
227  // Only 1-D vectors can be lowered to LLVM.
228  VectorType vectorTy = loadOrStoreOp.getVectorType();
229  if (vectorTy.getRank() > 1)
230  return failure();
231 
232  auto loc = loadOrStoreOp->getLoc();
233  MemRefType memRefTy = loadOrStoreOp.getMemRefType();
234 
235  // Resolve alignment.
236  unsigned align;
237  if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
238  return failure();
239 
240  // Resolve address.
241  auto vtype = this->typeConverter->convertType(loadOrStoreOp.getVectorType())
242  .template cast<VectorType>();
243  Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(),
244  adaptor.getIndices(), rewriter);
245  Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype);
246 
247  replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, ptr, align, rewriter);
248  return success();
249  }
250 };
251 
252 /// Conversion pattern for a vector.gather.
253 class VectorGatherOpConversion
254  : public ConvertOpToLLVMPattern<vector::GatherOp> {
255 public:
257 
259  matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
260  ConversionPatternRewriter &rewriter) const override {
261  auto loc = gather->getLoc();
262  MemRefType memRefType = gather.getMemRefType();
263 
264  // Resolve alignment.
265  unsigned align;
266  if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
267  return failure();
268 
269  // Resolve address.
270  Value ptrs;
271  VectorType vType = gather.getVectorType();
272  Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
273  adaptor.getIndices(), rewriter);
274  if (failed(getIndexedPtrs(rewriter, loc, adaptor.getBase(), ptr,
275  adaptor.getIndexVec(), memRefType, vType, ptrs)))
276  return failure();
277 
278  // Replace with the gather intrinsic.
279  rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
280  gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
281  adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
282  return success();
283  }
284 };
285 
286 /// Conversion pattern for a vector.scatter.
287 class VectorScatterOpConversion
288  : public ConvertOpToLLVMPattern<vector::ScatterOp> {
289 public:
291 
293  matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
294  ConversionPatternRewriter &rewriter) const override {
295  auto loc = scatter->getLoc();
296  MemRefType memRefType = scatter.getMemRefType();
297 
298  // Resolve alignment.
299  unsigned align;
300  if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
301  return failure();
302 
303  // Resolve address.
304  Value ptrs;
305  VectorType vType = scatter.getVectorType();
306  Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
307  adaptor.getIndices(), rewriter);
308  if (failed(getIndexedPtrs(rewriter, loc, adaptor.getBase(), ptr,
309  adaptor.getIndexVec(), memRefType, vType, ptrs)))
310  return failure();
311 
312  // Replace with the scatter intrinsic.
313  rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
314  scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(),
315  rewriter.getI32IntegerAttr(align));
316  return success();
317  }
318 };
319 
320 /// Conversion pattern for a vector.expandload.
321 class VectorExpandLoadOpConversion
322  : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
323 public:
325 
327  matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor,
328  ConversionPatternRewriter &rewriter) const override {
329  auto loc = expand->getLoc();
330  MemRefType memRefType = expand.getMemRefType();
331 
332  // Resolve address.
333  auto vtype = typeConverter->convertType(expand.getVectorType());
334  Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
335  adaptor.getIndices(), rewriter);
336 
337  rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
338  expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru());
339  return success();
340  }
341 };
342 
343 /// Conversion pattern for a vector.compressstore.
344 class VectorCompressStoreOpConversion
345  : public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
346 public:
348 
350  matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor,
351  ConversionPatternRewriter &rewriter) const override {
352  auto loc = compress->getLoc();
353  MemRefType memRefType = compress.getMemRefType();
354 
355  // Resolve address.
356  Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
357  adaptor.getIndices(), rewriter);
358 
359  rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
360  compress, adaptor.getValueToStore(), ptr, adaptor.getMask());
361  return success();
362  }
363 };
364 
365 /// Helper method to lower a `vector.reduction` op that performs an arithmetic
366 /// operation like add,mul, etc.. `VectorOp` is the LLVM vector intrinsic to use
367 /// and `ScalarOp` is the scalar operation used to add the accumulation value if
368 /// non-null.
369 template <class VectorOp, class ScalarOp>
370 static Value createIntegerReductionArithmeticOpLowering(
371  ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
372  Value vectorOperand, Value accumulator) {
373  Value result = rewriter.create<VectorOp>(loc, llvmType, vectorOperand);
374  if (accumulator)
375  result = rewriter.create<ScalarOp>(loc, accumulator, result);
376  return result;
377 }
378 
379 /// Helper method to lower a `vector.reduction` operation that performs
380 /// a comparison operation like `min`/`max`. `VectorOp` is the LLVM vector
381 /// intrinsic to use and `predicate` is the predicate to use to compare+combine
382 /// the accumulator value if non-null.
383 template <class VectorOp>
384 static Value createIntegerReductionComparisonOpLowering(
385  ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
386  Value vectorOperand, Value accumulator, LLVM::ICmpPredicate predicate) {
387  Value result = rewriter.create<VectorOp>(loc, llvmType, vectorOperand);
388  if (accumulator) {
389  Value cmp =
390  rewriter.create<LLVM::ICmpOp>(loc, predicate, accumulator, result);
391  result = rewriter.create<LLVM::SelectOp>(loc, cmp, accumulator, result);
392  }
393  return result;
394 }
395 
396 /// Conversion pattern for all vector reductions.
397 class VectorReductionOpConversion
398  : public ConvertOpToLLVMPattern<vector::ReductionOp> {
399 public:
400  explicit VectorReductionOpConversion(LLVMTypeConverter &typeConv,
401  bool reassociateFPRed)
403  reassociateFPReductions(reassociateFPRed) {}
404 
406  matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
407  ConversionPatternRewriter &rewriter) const override {
408  auto kind = reductionOp.getKind();
409  Type eltType = reductionOp.getDest().getType();
410  Type llvmType = typeConverter->convertType(eltType);
411  Value operand = adaptor.getVector();
412  Value acc = adaptor.getAcc();
413  Location loc = reductionOp.getLoc();
414  if (eltType.isIntOrIndex()) {
415  // Integer reductions: add/mul/min/max/and/or/xor.
416  Value result;
417  switch (kind) {
418  case vector::CombiningKind::ADD:
419  result =
420  createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_add,
421  LLVM::AddOp>(
422  rewriter, loc, llvmType, operand, acc);
423  break;
424  case vector::CombiningKind::MUL:
425  result =
426  createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_mul,
427  LLVM::MulOp>(
428  rewriter, loc, llvmType, operand, acc);
429  break;
430  case vector::CombiningKind::MINUI:
431  result = createIntegerReductionComparisonOpLowering<
432  LLVM::vector_reduce_umin>(rewriter, loc, llvmType, operand, acc,
433  LLVM::ICmpPredicate::ule);
434  break;
435  case vector::CombiningKind::MINSI:
436  result = createIntegerReductionComparisonOpLowering<
437  LLVM::vector_reduce_smin>(rewriter, loc, llvmType, operand, acc,
438  LLVM::ICmpPredicate::sle);
439  break;
440  case vector::CombiningKind::MAXUI:
441  result = createIntegerReductionComparisonOpLowering<
442  LLVM::vector_reduce_umax>(rewriter, loc, llvmType, operand, acc,
443  LLVM::ICmpPredicate::uge);
444  break;
445  case vector::CombiningKind::MAXSI:
446  result = createIntegerReductionComparisonOpLowering<
447  LLVM::vector_reduce_smax>(rewriter, loc, llvmType, operand, acc,
448  LLVM::ICmpPredicate::sge);
449  break;
450  case vector::CombiningKind::AND:
451  result =
452  createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_and,
453  LLVM::AndOp>(
454  rewriter, loc, llvmType, operand, acc);
455  break;
456  case vector::CombiningKind::OR:
457  result =
458  createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_or,
459  LLVM::OrOp>(
460  rewriter, loc, llvmType, operand, acc);
461  break;
462  case vector::CombiningKind::XOR:
463  result =
464  createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_xor,
465  LLVM::XOrOp>(
466  rewriter, loc, llvmType, operand, acc);
467  break;
468  default:
469  return failure();
470  }
471  rewriter.replaceOp(reductionOp, result);
472 
473  return success();
474  }
475 
476  if (!eltType.isa<FloatType>())
477  return failure();
478 
479  // Floating-point reductions: add/mul/min/max
480  if (kind == vector::CombiningKind::ADD) {
481  // Optional accumulator (or zero).
482  Value acc = adaptor.getOperands().size() > 1
483  ? adaptor.getOperands()[1]
484  : rewriter.create<LLVM::ConstantOp>(
485  reductionOp->getLoc(), llvmType,
486  rewriter.getZeroAttr(eltType));
487  rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>(
488  reductionOp, llvmType, acc, operand,
489  rewriter.getBoolAttr(reassociateFPReductions));
490  } else if (kind == vector::CombiningKind::MUL) {
491  // Optional accumulator (or one).
492  Value acc = adaptor.getOperands().size() > 1
493  ? adaptor.getOperands()[1]
494  : rewriter.create<LLVM::ConstantOp>(
495  reductionOp->getLoc(), llvmType,
496  rewriter.getFloatAttr(eltType, 1.0));
497  rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
498  reductionOp, llvmType, acc, operand,
499  rewriter.getBoolAttr(reassociateFPReductions));
500  } else if (kind == vector::CombiningKind::MINF)
501  // FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle
502  // NaNs/-0.0/+0.0 in the same way.
503  rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(reductionOp,
504  llvmType, operand);
505  else if (kind == vector::CombiningKind::MAXF)
506  // FIXME: MLIR's 'maxf' and LLVM's 'vector_reduce_fmax' do not handle
507  // NaNs/-0.0/+0.0 in the same way.
508  rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(reductionOp,
509  llvmType, operand);
510  else
511  return failure();
512  return success();
513  }
514 
515 private:
516  const bool reassociateFPReductions;
517 };
518 
519 class VectorShuffleOpConversion
520  : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
521 public:
523 
525  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
526  ConversionPatternRewriter &rewriter) const override {
527  auto loc = shuffleOp->getLoc();
528  auto v1Type = shuffleOp.getV1VectorType();
529  auto v2Type = shuffleOp.getV2VectorType();
530  auto vectorType = shuffleOp.getVectorType();
531  Type llvmType = typeConverter->convertType(vectorType);
532  auto maskArrayAttr = shuffleOp.getMask();
533 
534  // Bail if result type cannot be lowered.
535  if (!llvmType)
536  return failure();
537 
538  // Get rank and dimension sizes.
539  int64_t rank = vectorType.getRank();
540  assert(v1Type.getRank() == rank);
541  assert(v2Type.getRank() == rank);
542  int64_t v1Dim = v1Type.getDimSize(0);
543 
544  // For rank 1, where both operands have *exactly* the same vector type,
545  // there is direct shuffle support in LLVM. Use it!
546  if (rank == 1 && v1Type == v2Type) {
547  Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
548  loc, adaptor.getV1(), adaptor.getV2(), maskArrayAttr);
549  rewriter.replaceOp(shuffleOp, llvmShuffleOp);
550  return success();
551  }
552 
553  // For all other cases, insert the individual values individually.
554  Type eltType;
555  if (auto arrayType = llvmType.dyn_cast<LLVM::LLVMArrayType>())
556  eltType = arrayType.getElementType();
557  else
558  eltType = llvmType.cast<VectorType>().getElementType();
559  Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
560  int64_t insPos = 0;
561  for (const auto &en : llvm::enumerate(maskArrayAttr)) {
562  int64_t extPos = en.value().cast<IntegerAttr>().getInt();
563  Value value = adaptor.getV1();
564  if (extPos >= v1Dim) {
565  extPos -= v1Dim;
566  value = adaptor.getV2();
567  }
568  Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
569  eltType, rank, extPos);
570  insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
571  llvmType, rank, insPos++);
572  }
573  rewriter.replaceOp(shuffleOp, insert);
574  return success();
575  }
576 };
577 
578 class VectorExtractElementOpConversion
579  : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
580 public:
582  vector::ExtractElementOp>::ConvertOpToLLVMPattern;
583 
585  matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
586  ConversionPatternRewriter &rewriter) const override {
587  auto vectorType = extractEltOp.getVectorType();
588  auto llvmType = typeConverter->convertType(vectorType.getElementType());
589 
590  // Bail if result type cannot be lowered.
591  if (!llvmType)
592  return failure();
593 
594  if (vectorType.getRank() == 0) {
595  Location loc = extractEltOp.getLoc();
596  auto idxType = rewriter.getIndexType();
597  auto zero = rewriter.create<LLVM::ConstantOp>(
598  loc, typeConverter->convertType(idxType),
599  rewriter.getIntegerAttr(idxType, 0));
600  rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
601  extractEltOp, llvmType, adaptor.getVector(), zero);
602  return success();
603  }
604 
605  rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
606  extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition());
607  return success();
608  }
609 };
610 
611 class VectorExtractOpConversion
612  : public ConvertOpToLLVMPattern<vector::ExtractOp> {
613 public:
615 
617  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
618  ConversionPatternRewriter &rewriter) const override {
619  auto loc = extractOp->getLoc();
620  auto vectorType = extractOp.getVectorType();
621  auto resultType = extractOp.getResult().getType();
622  auto llvmResultType = typeConverter->convertType(resultType);
623  auto positionArrayAttr = extractOp.getPosition();
624 
625  // Bail if result type cannot be lowered.
626  if (!llvmResultType)
627  return failure();
628 
629  // Extract entire vector. Should be handled by folder, but just to be safe.
630  if (positionArrayAttr.empty()) {
631  rewriter.replaceOp(extractOp, adaptor.getVector());
632  return success();
633  }
634 
635  // One-shot extraction of vector from array (only requires extractvalue).
636  if (resultType.isa<VectorType>()) {
637  Value extracted = rewriter.create<LLVM::ExtractValueOp>(
638  loc, llvmResultType, adaptor.getVector(), positionArrayAttr);
639  rewriter.replaceOp(extractOp, extracted);
640  return success();
641  }
642 
643  // Potential extraction of 1-D vector from array.
644  auto *context = extractOp->getContext();
645  Value extracted = adaptor.getVector();
646  auto positionAttrs = positionArrayAttr.getValue();
647  if (positionAttrs.size() > 1) {
648  auto oneDVectorType = reducedVectorTypeBack(vectorType);
649  auto nMinusOnePositionAttrs =
650  ArrayAttr::get(context, positionAttrs.drop_back());
651  extracted = rewriter.create<LLVM::ExtractValueOp>(
652  loc, typeConverter->convertType(oneDVectorType), extracted,
653  nMinusOnePositionAttrs);
654  }
655 
656  // Remaining extraction of element from 1-D LLVM vector
657  auto position = positionAttrs.back().cast<IntegerAttr>();
658  auto i64Type = IntegerType::get(rewriter.getContext(), 64);
659  auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
660  extracted =
661  rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
662  rewriter.replaceOp(extractOp, extracted);
663 
664  return success();
665  }
666 };
667 
668 /// Conversion pattern that turns a vector.fma on a 1-D vector
669 /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
670 /// This does not match vectors of n >= 2 rank.
671 ///
672 /// Example:
673 /// ```
674 /// vector.fma %a, %a, %a : vector<8xf32>
675 /// ```
676 /// is converted to:
677 /// ```
678 /// llvm.intr.fmuladd(%va, %va, %va) :
679 /// (vector<8xf32>, vector<8xf32>, vector<8xf32>)
680 /// -> vector<8xf32>
681 /// ```
682 class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
683 public:
685 
687  matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
688  ConversionPatternRewriter &rewriter) const override {
689  VectorType vType = fmaOp.getVectorType();
690  if (vType.getRank() != 1)
691  return failure();
692  rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(
693  fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
694  return success();
695  }
696 };
697 
698 class VectorInsertElementOpConversion
699  : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
700 public:
702 
704  matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor,
705  ConversionPatternRewriter &rewriter) const override {
706  auto vectorType = insertEltOp.getDestVectorType();
707  auto llvmType = typeConverter->convertType(vectorType);
708 
709  // Bail if result type cannot be lowered.
710  if (!llvmType)
711  return failure();
712 
713  if (vectorType.getRank() == 0) {
714  Location loc = insertEltOp.getLoc();
715  auto idxType = rewriter.getIndexType();
716  auto zero = rewriter.create<LLVM::ConstantOp>(
717  loc, typeConverter->convertType(idxType),
718  rewriter.getIntegerAttr(idxType, 0));
719  rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
720  insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero);
721  return success();
722  }
723 
724  rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
725  insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(),
726  adaptor.getPosition());
727  return success();
728  }
729 };
730 
731 class VectorInsertOpConversion
732  : public ConvertOpToLLVMPattern<vector::InsertOp> {
733 public:
735 
737  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
738  ConversionPatternRewriter &rewriter) const override {
739  auto loc = insertOp->getLoc();
740  auto sourceType = insertOp.getSourceType();
741  auto destVectorType = insertOp.getDestVectorType();
742  auto llvmResultType = typeConverter->convertType(destVectorType);
743  auto positionArrayAttr = insertOp.getPosition();
744 
745  // Bail if result type cannot be lowered.
746  if (!llvmResultType)
747  return failure();
748 
749  // Overwrite entire vector with value. Should be handled by folder, but
750  // just to be safe.
751  if (positionArrayAttr.empty()) {
752  rewriter.replaceOp(insertOp, adaptor.getSource());
753  return success();
754  }
755 
756  // One-shot insertion of a vector into an array (only requires insertvalue).
757  if (sourceType.isa<VectorType>()) {
758  Value inserted = rewriter.create<LLVM::InsertValueOp>(
759  loc, llvmResultType, adaptor.getDest(), adaptor.getSource(),
760  positionArrayAttr);
761  rewriter.replaceOp(insertOp, inserted);
762  return success();
763  }
764 
765  // Potential extraction of 1-D vector from array.
766  auto *context = insertOp->getContext();
767  Value extracted = adaptor.getDest();
768  auto positionAttrs = positionArrayAttr.getValue();
769  auto position = positionAttrs.back().cast<IntegerAttr>();
770  auto oneDVectorType = destVectorType;
771  if (positionAttrs.size() > 1) {
772  oneDVectorType = reducedVectorTypeBack(destVectorType);
773  auto nMinusOnePositionAttrs =
774  ArrayAttr::get(context, positionAttrs.drop_back());
775  extracted = rewriter.create<LLVM::ExtractValueOp>(
776  loc, typeConverter->convertType(oneDVectorType), extracted,
777  nMinusOnePositionAttrs);
778  }
779 
780  // Insertion of an element into a 1-D LLVM vector.
781  auto i64Type = IntegerType::get(rewriter.getContext(), 64);
782  auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
783  Value inserted = rewriter.create<LLVM::InsertElementOp>(
784  loc, typeConverter->convertType(oneDVectorType), extracted,
785  adaptor.getSource(), constant);
786 
787  // Potential insertion of resulting 1-D vector into array.
788  if (positionAttrs.size() > 1) {
789  auto nMinusOnePositionAttrs =
790  ArrayAttr::get(context, positionAttrs.drop_back());
791  inserted = rewriter.create<LLVM::InsertValueOp>(
792  loc, llvmResultType, adaptor.getDest(), inserted,
793  nMinusOnePositionAttrs);
794  }
795 
796  rewriter.replaceOp(insertOp, inserted);
797  return success();
798  }
799 };
800 
801 /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
802 ///
803 /// Example:
804 /// ```
805 /// %d = vector.fma %a, %b, %c : vector<2x4xf32>
806 /// ```
807 /// is rewritten into:
808 /// ```
809 /// %r = vector.splat %f0 : vector<2x4xf32>
810 /// %va = vector.extract %a[0] : vector<2x4xf32>
811 /// %vb = vector.extract %b[0] : vector<2x4xf32>
812 /// %vc = vector.extract %c[0] : vector<2x4xf32>
813 /// %vd = vector.fma %va, %vb, %vc : vector<4xf32>
814 /// %r2 = vector.insert %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
815 /// %va2 = vector.extract %a[1] : vector<2x4xf32>
816 /// %vb2 = vector.extract %b[1] : vector<2x4xf32>
817 /// %vc2 = vector.extract %c[1] : vector<2x4xf32>
818 /// %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
819 /// %r3 = vector.insert %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
820 /// // %r3 holds the final value.
821 /// ```
822 class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
823 public:
825 
826  void initialize() {
827  // This pattern recursively unpacks one dimension at a time. The recursion
828 // is bounded because the rank is strictly decreasing.
829  setHasBoundedRewriteRecursion();
830  }
831 
832  LogicalResult matchAndRewrite(FMAOp op,
833  PatternRewriter &rewriter) const override {
834  auto vType = op.getVectorType();
835  if (vType.getRank() < 2)
836  return failure();
837 
838  auto loc = op.getLoc();
839  auto elemType = vType.getElementType();
840  Value zero = rewriter.create<arith::ConstantOp>(
841  loc, elemType, rewriter.getZeroAttr(elemType));
842  Value desc = rewriter.create<vector::SplatOp>(loc, vType, zero);
843  for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
844  Value extrLHS = rewriter.create<ExtractOp>(loc, op.getLhs(), i);
845  Value extrRHS = rewriter.create<ExtractOp>(loc, op.getRhs(), i);
846  Value extrACC = rewriter.create<ExtractOp>(loc, op.getAcc(), i);
847  Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
848  desc = rewriter.create<InsertOp>(loc, fma, desc, i);
849  }
850  rewriter.replaceOp(op, desc);
851  return success();
852  }
853 };
854 
855 /// Returns the strides if the memory underlying `memRefType` has a contiguous
856 /// static layout.
858 computeContiguousStrides(MemRefType memRefType) {
859  int64_t offset;
860  SmallVector<int64_t, 4> strides;
861  if (failed(getStridesAndOffset(memRefType, strides, offset)))
862  return None;
863  if (!strides.empty() && strides.back() != 1)
864  return None;
865  // If no layout or identity layout, this is contiguous by definition.
866  if (memRefType.getLayout().isIdentity())
867  return strides;
868 
869 // Otherwise, we must determine contiguity from shapes. This can only ever
870  // work in static cases because MemRefType is underspecified to represent
871  // contiguous dynamic shapes in other ways than with just empty/identity
872  // layout.
873  auto sizes = memRefType.getShape();
874  for (int index = 0, e = strides.size() - 1; index < e; ++index) {
875  if (ShapedType::isDynamic(sizes[index + 1]) ||
876  ShapedType::isDynamicStrideOrOffset(strides[index]) ||
877  ShapedType::isDynamicStrideOrOffset(strides[index + 1]))
878  return None;
879  if (strides[index] != strides[index + 1] * sizes[index + 1])
880  return None;
881  }
882  return strides;
883 }
884 
/// Conversion pattern for vector::TypeCastOp. Builds a fresh target memref
/// descriptor that reuses the source descriptor's allocated and aligned
/// pointers (bitcast to the target element pointer type), sets the offset to
/// zero, and fills sizes and strides with constants taken from the target
/// memref type. Among other checks, the pattern fails unless both memref types
/// have static shapes and `computeContiguousStrides` succeeds for both.
885 class VectorTypeCastOpConversion
886  : public ConvertOpToLLVMPattern<vector::TypeCastOp> {
887 public:
889 
891  matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor,
892  ConversionPatternRewriter &rewriter) const override {
893  auto loc = castOp->getLoc();
894  MemRefType sourceMemRefType =
895  castOp.getOperand().getType().cast<MemRefType>();
896  MemRefType targetMemRefType = castOp.getType();
897 
898  // Only static shape casts supported atm.
899  if (!sourceMemRefType.hasStaticShape() ||
900  !targetMemRefType.hasStaticShape())
901  return failure();
902 
     // The converted source operand must be an LLVM struct so that it can be
     // wrapped in a MemRefDescriptor.
903  auto llvmSourceDescriptorTy =
904  adaptor.getOperands()[0].getType().dyn_cast<LLVM::LLVMStructType>();
905  if (!llvmSourceDescriptorTy)
906  return failure();
907  MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]);
908 
     // Likewise, the target memref type must convert to an LLVM struct.
909  auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
910  .dyn_cast_or_null<LLVM::LLVMStructType>();
911  if (!llvmTargetDescriptorTy)
912  return failure();
913 
914  // Only contiguous source buffers supported atm.
915  auto sourceStrides = computeContiguousStrides(sourceMemRefType);
916  if (!sourceStrides)
917  return failure();
918  auto targetStrides = computeContiguousStrides(targetMemRefType);
919  if (!targetStrides)
920  return failure();
921  // Only support static strides for now, regardless of contiguity.
922  if (llvm::any_of(*targetStrides, [](int64_t stride) {
923  return ShapedType::isDynamicStrideOrOffset(stride);
924  }))
925  return failure();
926 
927  auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
928 
929  // Create descriptor.
930  auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
931  Type llvmTargetElementTy = desc.getElementPtrType();
932  // Set allocated ptr.
933  Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
934  allocated =
935  rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
936  desc.setAllocatedPtr(rewriter, loc, allocated);
937  // Set aligned ptr.
938  Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
939  ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
940  desc.setAlignedPtr(rewriter, loc, ptr);
941  // Fill offset 0.
     // NOTE(review): the source descriptor's offset is never read here; the
     // target offset is unconditionally 0 -- confirm this is intended.
942  auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
943  auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
944  desc.setOffset(rewriter, loc, zero);
945 
946  // Fill size and stride descriptors in memref.
947  for (const auto &indexedSize :
948  llvm::enumerate(targetMemRefType.getShape())) {
949  int64_t index = indexedSize.index();
950  auto sizeAttr =
951  rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
952  auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
953  desc.setSize(rewriter, loc, index, size);
954  auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
955  (*targetStrides)[index]);
956  auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
957  desc.setStride(rewriter, loc, index, stride);
958  }
959 
960  rewriter.replaceOp(castOp, {desc});
961  return success();
962  }
963 };
964 
965 /// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only).
966 /// Non-scalable versions of this operation are handled in Vector Transforms.
    /// The mask is produced by comparing a step vector against the splatted
    /// bound with a signed less-than.
967 class VectorCreateMaskOpRewritePattern
968  : public OpRewritePattern<vector::CreateMaskOp> {
969 public:
     // `enableIndexOpt` selects 32-bit index elements when true and 64-bit
     // index elements otherwise (see `force32BitVectorIndices` below).
970  explicit VectorCreateMaskOpRewritePattern(MLIRContext *context,
971  bool enableIndexOpt)
973  force32BitVectorIndices(enableIndexOpt) {}
974 
975  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
976  PatternRewriter &rewriter) const override {
977  auto dstType = op.getType();
     // Reject anything that is not a rank-1 scalable vector.
978  if (dstType.getRank() != 1 || !dstType.cast<VectorType>().isScalable())
979  return failure();
980  IntegerType idxType =
981  force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
982  auto loc = op->getLoc();
983  Value indices = rewriter.create<LLVM::StepVectorOp>(
984  loc, LLVM::getVectorType(idxType, dstType.getShape()[0],
985  /*isScalable=*/true));
     // The result is 1-D, so only the first operand is used as the bound; it
     // is cast to `idxType` before being splatted.
986  auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType,
987  op.getOperand(0));
988  Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
989  Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
990  indices, bounds);
991  rewriter.replaceOp(op, comp);
992  return success();
993  }
994 
995 private:
     // When true, index vectors use i32 elements; otherwise i64.
996  const bool force32BitVectorIndices;
997 };
998 
999 class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
1000 public:
1002 
1003  // Proof-of-concept lowering implementation that relies on a small
1004  // runtime support library, which only needs to provide a few
1005  // printing methods (single value for all data types, opening/closing
1006  // bracket, comma, newline). The lowering fully unrolls a vector
1007  // in terms of these elementary printing operations. The advantage
1008  // of this approach is that the library can remain unaware of all
1009  // low-level implementation details of vectors while still supporting
1010  // output of any shaped and dimensioned vector. Due to full unrolling,
1011  // this approach is less suited for very large vectors though.
1012  //
1013  // TODO: rely solely on libc in future? something else?
1014  //
1016  matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor,
1017  ConversionPatternRewriter &rewriter) const override {
1018  Type printType = printOp.getPrintType();
1019 
      // Give up if the printed type has no LLVM conversion.
1020  if (typeConverter->convertType(printType) == nullptr)
1021  return failure();
1022 
1023  // Make sure element type has runtime support.
1024  PrintConversion conversion = PrintConversion::None;
1025  VectorType vectorType = printType.dyn_cast<VectorType>();
1026  Type eltType = vectorType ? vectorType.getElementType() : printType;
1027  Operation *printer;
1028  if (eltType.isF32()) {
1029  printer =
1030  LLVM::lookupOrCreatePrintF32Fn(printOp->getParentOfType<ModuleOp>());
1031  } else if (eltType.isF64()) {
1032  printer =
1033  LLVM::lookupOrCreatePrintF64Fn(printOp->getParentOfType<ModuleOp>());
1034  } else if (eltType.isIndex()) {
1035  printer =
1036  LLVM::lookupOrCreatePrintU64Fn(printOp->getParentOfType<ModuleOp>());
1037  } else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
1038  // Integers need a zero or sign extension on the operand
1039  // (depending on the source type) as well as a signed or
1040  // unsigned print method. Up to 64-bit is supported.
1041  unsigned width = intTy.getWidth();
1042  if (intTy.isUnsigned()) {
1043  if (width <= 64) {
1044  if (width < 64)
1045  conversion = PrintConversion::ZeroExt64;
1047  printOp->getParentOfType<ModuleOp>());
1048  } else {
1049  return failure();
1050  }
1051  } else {
1052  assert(intTy.isSignless() || intTy.isSigned());
1053  if (width <= 64) {
1054  // Note that we *always* zero extend booleans (1-bit integers),
1055  // so that true/false is printed as 1/0 rather than -1/0.
1056  if (width == 1)
1057  conversion = PrintConversion::ZeroExt64;
1058  else if (width < 64)
1059  conversion = PrintConversion::SignExt64;
1061  printOp->getParentOfType<ModuleOp>());
1062  } else {
1063  return failure();
1064  }
1065  }
1066  } else {
      // Any other element type is unsupported.
1067  return failure();
1068  }
1069 
1070  // Unroll vector into elementary print calls.
1071  int64_t rank = vectorType ? vectorType.getRank() : 0;
1072  Type type = vectorType ? vectorType : eltType;
1073  emitRanks(rewriter, printOp, adaptor.getSource(), type, printer, rank,
1074  conversion);
1075  emitCall(rewriter, printOp->getLoc(),
1077  printOp->getParentOfType<ModuleOp>()));
1078  rewriter.eraseOp(printOp);
1079  return success();
1080  }
1081 
1082 private:
      // How a scalar operand is widened to 64 bits before it is printed.
1083  enum class PrintConversion {
1084  // clang-format off
1085  None,
1086  ZeroExt64,
1087  SignExt64
1088  // clang-format on
1089  };
1090 
      // Recursively emits the print calls for `value` of type `type`. A scalar
      // is extended to 64 bits when `conversion` requests it and then passed
      // to `printer`. A vector is unrolled along its leading dimension, with a
      // comma call emitted between consecutive elements.
1091  void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
1092  Value value, Type type, Operation *printer, int64_t rank,
1093  PrintConversion conversion) const {
1094  VectorType vectorType = type.dyn_cast<VectorType>();
1095  Location loc = op->getLoc();
1096  if (!vectorType) {
1097  assert(rank == 0 && "The scalar case expects rank == 0");
1098  switch (conversion) {
1099  case PrintConversion::ZeroExt64:
1100  value = rewriter.create<arith::ExtUIOp>(
1101  loc, IntegerType::get(rewriter.getContext(), 64), value);
1102  break;
1103  case PrintConversion::SignExt64:
1104  value = rewriter.create<arith::ExtSIOp>(
1105  loc, IntegerType::get(rewriter.getContext(), 64), value);
1106  break;
1107  case PrintConversion::None:
1108  break;
1109  }
1110  emitCall(rewriter, loc, printer, value);
1111  return;
1112  }
1113 
1114  emitCall(rewriter, loc,
1118 
      // Rank 0 or 1: extract and print each scalar element. A rank-0 vector is
      // treated as holding exactly one element.
1119  if (rank <= 1) {
1120  auto reducedType = vectorType.getElementType();
1121  auto llvmType = typeConverter->convertType(reducedType);
1122  int64_t dim = rank == 0 ? 1 : vectorType.getDimSize(0);
1123  for (int64_t d = 0; d < dim; ++d) {
1124  Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
1125  llvmType, /*rank=*/0, /*pos=*/d);
1126  emitRanks(rewriter, op, nestedVal, reducedType, printer, /*rank=*/0,
1127  conversion);
1128  if (d != dim - 1)
1129  emitCall(rewriter, loc, printComma);
1130  }
1131  emitCall(
1132  rewriter, loc,
1134  return;
1135  }
1136 
      // Rank >= 2: peel off the leading dimension and recurse on each
      // sub-vector with rank - 1.
1137  int64_t dim = vectorType.getDimSize(0);
1138  for (int64_t d = 0; d < dim; ++d) {
1139  auto reducedType = reducedVectorTypeFront(vectorType);
1140  auto llvmType = typeConverter->convertType(reducedType);
1141  Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
1142  llvmType, rank, d);
1143  emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
1144  conversion);
1145  if (d != dim - 1)
1146  emitCall(rewriter, loc, printComma);
1147  }
1148  emitCall(rewriter, loc,
1150  }
1151 
1152  // Helper to emit a call.
      // The call targets the symbol of `ref`, passes `params` as operands, and
      // is created with an empty result type list.
1153  static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
1154  Operation *ref, ValueRange params = ValueRange()) {
1155  rewriter.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(ref),
1156  params);
1157  }
1158 };
1159 
1160 /// The Splat operation is lowered to an insertelement + a shufflevector
1161 /// operation. Splat to only 0-d and 1-d vector result types are lowered.
1162 struct VectorSplatOpLowering : public ConvertOpToLLVMPattern<vector::SplatOp> {
1164 
1166  matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
1167  ConversionPatternRewriter &rewriter) const override {
1168  VectorType resultType = splatOp.getType().cast<VectorType>();
      // Results of rank 2 or higher are not handled by this pattern.
1169  if (resultType.getRank() > 1)
1170  return failure();
1171 
1172  // First insert it into an undef vector so we can shuffle it.
1173  auto vectorType = typeConverter->convertType(splatOp.getType());
1174  Value undef = rewriter.create<LLVM::UndefOp>(splatOp.getLoc(), vectorType);
      // 32-bit zero constant used as the insertion index.
1175  auto zero = rewriter.create<LLVM::ConstantOp>(
1176  splatOp.getLoc(),
1177  typeConverter->convertType(rewriter.getIntegerType(32)),
1178  rewriter.getZeroAttr(rewriter.getIntegerType(32)));
1179 
1180  // For 0-d vector, we simply do `insertelement`.
1181  if (resultType.getRank() == 0) {
1182  rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
1183  splatOp, vectorType, undef, adaptor.getInput(), zero);
1184  return success();
1185  }
1186 
1187  // For 1-d vector, we additionally do a `shufflevector`.
1188  auto v = rewriter.create<LLVM::InsertElementOp>(
1189  splatOp.getLoc(), vectorType, undef, adaptor.getInput(), zero);
1190 
      // An all-zero mask with one entry per element of the result vector.
1191  int64_t width = splatOp.getType().cast<VectorType>().getDimSize(0);
1192  SmallVector<int32_t, 4> zeroValues(width, 0);
1193 
1194  // Shuffle the value across the desired number of elements.
1195  ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues);
1196  rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(splatOp, v, undef,
1197  zeroAttrs);
1198  return success();
1199  }
1200 };
1201 
1202 /// The Splat operation is lowered to an insertelement + a shufflevector
1203 /// operation. This pattern only lowers splats whose result vector has rank 2
1204 /// or higher; 0-d and 1-d results are handled by VectorSplatOpLowering.
1205 struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
1207 
1209  matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor,
1210  ConversionPatternRewriter &rewriter) const override {
1211  VectorType resultType = splatOp.getType();
1212  if (resultType.getRank() <= 1)
1213  return failure();
1214 
1215  // First insert it into an undef vector so we can shuffle it.
1216  auto loc = splatOp.getLoc();
1217  auto vectorTypeInfo =
1218  LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter());
1219  auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
1220  auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
      // Bail out if either the n-D or the 1-D LLVM vector type is unavailable.
1221  if (!llvmNDVectorTy || !llvm1DVectorTy)
1222  return failure();
1223 
1224  // Construct returned value.
1225  Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmNDVectorTy);
1226 
1227  // Construct a 1-D vector with the splatted value that we insert in all the
1228  // places within the returned descriptor.
1229  Value vdesc = rewriter.create<LLVM::UndefOp>(loc, llvm1DVectorTy);
1230  auto zero = rewriter.create<LLVM::ConstantOp>(
1231  loc, typeConverter->convertType(rewriter.getIntegerType(32)),
1232  rewriter.getZeroAttr(rewriter.getIntegerType(32)));
1233  Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvm1DVectorTy, vdesc,
1234  adaptor.getInput(), zero);
1235 
1236  // Shuffle the value across the desired number of elements.
      // The all-zero mask has one entry per element of the innermost dimension.
1237  int64_t width = resultType.getDimSize(resultType.getRank() - 1);
1238  SmallVector<int32_t, 4> zeroValues(width, 0);
1239  ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues);
1240  v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroAttrs);
1241 
1242  // Iterate over the linear index, convert to coords space and insert the
1243  // splatted 1-D vector in each position.
1244  nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) {
1245  desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmNDVectorTy, desc, v,
1246  position);
1247  });
1248  rewriter.replaceOp(splatOp, desc);
1249  return success();
1250  }
1251 };
1252 
1253 } // namespace
1254 
1255 /// Populate the given list with patterns that convert from Vector to LLVM.
     /// `reassociateFPReductions` is forwarded to VectorReductionOpConversion and
     /// `force32BitVectorIndices` to VectorCreateMaskOpRewritePattern.
1257  LLVMTypeConverter &converter, RewritePatternSet &patterns,
1258  bool reassociateFPReductions, bool force32BitVectorIndices) {
1259  MLIRContext *ctx = converter.getDialect()->getContext();
      // These patterns are constructed from the context or take an extra flag,
      // so they are added individually.
1260  patterns.add<VectorFMAOpNDRewritePattern>(ctx);
1262  patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
1263  patterns.add<VectorCreateMaskOpRewritePattern>(ctx, force32BitVectorIndices);
      // The remaining conversion patterns are all constructed from the type
      // converter alone.
1264  patterns
1265  .add<VectorBitCastOpConversion, VectorShuffleOpConversion,
1266  VectorExtractElementOpConversion, VectorExtractOpConversion,
1267  VectorFMAOp1DConversion, VectorInsertElementOpConversion,
1268  VectorInsertOpConversion, VectorPrintOpConversion,
1269  VectorTypeCastOpConversion, VectorScaleOpConversion,
1270  VectorLoadStoreConversion<vector::LoadOp, vector::LoadOpAdaptor>,
1271  VectorLoadStoreConversion<vector::MaskedLoadOp,
1272  vector::MaskedLoadOpAdaptor>,
1273  VectorLoadStoreConversion<vector::StoreOp, vector::StoreOpAdaptor>,
1274  VectorLoadStoreConversion<vector::MaskedStoreOp,
1275  vector::MaskedStoreOpAdaptor>,
1276  VectorGatherOpConversion, VectorScatterOpConversion,
1277  VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
1278  VectorSplatOpLowering, VectorSplatNdOpLowering>(converter);
1279  // Transfer ops with rank > 1 are handled by VectorToSCF.
1280  populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
1281 }
1282 
1284  LLVMTypeConverter &converter, RewritePatternSet &patterns) {
      // Adds the matmul and flat-transpose conversion patterns, both constructed
      // from the type converter.
1285  patterns.add<VectorMatmulOpConversion>(converter);
1286  patterns.add<VectorFlatTransposeOpConversion>(converter);
1287 }
TODO: Remove this file when SCCP and integer range analysis have been ported to the new framework...
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:132
NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType, LLVMTypeConverter &converter)
MLIRContext * getContext() const
Definition: Builders.h:54
Type getFixedVectorType(Type elementType, unsigned numElements)
Creates an LLVM dialect-compatible type with the given element type and length.
Definition: LLVMTypes.cpp:908
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:600
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(ModuleOp moduleOp)
Attribute getZeroAttr(Type type)
Definition: Builders.cpp:264
Generic implementation of one-to-one conversion from "SourceOp" to "TargetOp" where the latter belong...
Definition: Pattern.h:188
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results)
Convert the given type.
Type getVectorType(Type elementType, unsigned numElements, bool isScalable=false)
Creates an LLVM dialect-compatible vector type with the given element type and length.
Definition: LLVMTypes.cpp:893
Utility class to translate MLIR LLVM dialect types to LLVM IR.
Definition: TypeToLLVM.h:39
LLVM::LLVMFuncOp lookupOrCreatePrintU64Fn(ModuleOp moduleOp)
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
static Value insertOne(ConversionPatternRewriter &rewriter, LLVMTypeConverter &typeConverter, Location loc, Value val1, Value val2, Type llvmType, int64_t rank, int64_t pos)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity...
Definition: SPIRVOps.cpp:688
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
void populateVectorToLLVMMatrixConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to convert from Vector contractions to LLVM Matrix Intrinsics.
static VectorType reducedVectorTypeBack(VectorType tp)
static VectorType reducedVectorTypeFront(VectorType tp)
LLVM::LLVMPointerType getElementPtrType()
Returns the (LLVM) pointer type this descriptor contains.
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:220
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:176
void populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, bool reassociateFPReductions=false, bool force32BitVectorIndices=false)
Collect a set of patterns to convert from the Vector dialect to LLVM.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing the results of an operation.
LLVM::LLVMFuncOp lookupOrCreatePrintCommaFn(ModuleOp moduleOp)
static constexpr const bool value
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:193
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:148
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:380
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LLVM::LLVMFuncOp lookupOrCreatePrintI64Fn(ModuleOp moduleOp)
Helper functions to lookup or create the declaration for commonly used external C function calls...
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
void populateVectorInsertExtractStridedSliceTransforms(RewritePatternSet &patterns)
Populate patterns with the following patterns.
LLVM::LLVMFuncOp lookupOrCreatePrintNewlineFn(ModuleOp moduleOp)
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:170
U dyn_cast() const
Definition: Types.h:256
MLIR_CRUNNERUTILS_EXPORT void printComma()
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Definition: MemRefBuilder.h:33
const llvm::DataLayout & getDataLayout()
Returns the data layout to use during and after conversion.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:234
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:58
static LLVMPointerType get(MLIRContext *context, unsigned addressSpace=0)
Gets or creates an instance of LLVM dialect pointer type pointing to an object of pointee type in the...
Definition: LLVMTypes.cpp:188
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:38
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:215
static MemRefDescriptor undef(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating an undef value of the descriptor type.
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:161
static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc, Value memref, Value base, Value index, MemRefType memRefType, VectorType vType, Value &ptrs)
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
LLVM dialect array type.
Definition: LLVMTypes.h:75
LLVM::LLVMFuncOp lookupOrCreatePrintCloseFn(ModuleOp moduleOp)
IntegerType getI64Type()
Definition: Builders.cpp:56
Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, Type targetType, Value value)
Create a cast from an index-like value (index or integer) to another index-like value.
Definition: Utils.cpp:62
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:355
Type getType() const
Return the type of this attribute.
Definition: Attributes.h:66
static Value extractOne(ConversionPatternRewriter &rewriter, LLVMTypeConverter &typeConverter, Location loc, Value val, Type llvmType, int64_t rank, int64_t pos)
LLVM::LLVMFuncOp lookupOrCreatePrintOpenFn(ModuleOp moduleOp)
Type getType() const
Return the type of this value.
Definition: Value.h:118
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&... args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
IndexType getIndexType()
Definition: Builders.cpp:48
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:451
Do not split vector transfer operations.
static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc, Value ptr, MemRefType memRefType, Type vt)
LLVM dialect structure type representing a collection of different-typed elements manipulated togethe...
Definition: LLVMTypes.h:277
static VectorType vectorType(CodeGen &codegen, Type etp)
Constructs vector type.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:30
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:87
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder, function_ref< void(ArrayAttr)> fun)
void populateVectorTransferLoweringPatterns(RewritePatternSet &patterns, llvm::Optional< unsigned > maxTransferRank=llvm::None)
Collect a set of transfer read/write lowering patterns.
This class implements a pattern rewriter for use with ConversionPatterns.
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
Definition: Types.cpp:85
U cast() const
Definition: Value.h:108
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
unsigned getPreferredAlignment(Type type, const llvm::DataLayout &layout)
Returns the preferred alignment for the type given the data layout.
Definition: TypeToLLVM.cpp:191
LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, MemRefType memrefType, unsigned &align)
bool isa() const
Definition: Types.h:246
LLVM::LLVMFuncOp lookupOrCreatePrintF64Fn(ModuleOp moduleOp)
This class provides an abstraction over the different types of ranges over Values.
IntegerType getI32Type()
Definition: Builders.cpp:54
LLVM::LLVMDialect * getDialect()
Returns the LLVM dialect.
Definition: TypeConverter.h:79
U cast() const
Definition: Types.h:262