ConvertVectorToLLVM.cpp
1//===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
31#include "llvm/ADT/APFloat.h"
32#include "llvm/IR/LLVMContext.h"
33#include "llvm/Support/Casting.h"
34
35#include <optional>
36
37using namespace mlir;
38using namespace mlir::vector;
39
40// Helper that picks the proper sequence for inserting.
41static Value insertOne(ConversionPatternRewriter &rewriter,
42 const LLVMTypeConverter &typeConverter, Location loc,
43 Value val1, Value val2, Type llvmType, int64_t rank,
44 int64_t pos) {
45 assert(rank > 0 && "0-D vector corner case should have been handled already");
46 if (rank == 1) {
47 auto idxType = rewriter.getIndexType();
48 auto constant = LLVM::ConstantOp::create(
49 rewriter, loc, typeConverter.convertType(idxType),
50 rewriter.getIntegerAttr(idxType, pos));
51 return LLVM::InsertElementOp::create(rewriter, loc, llvmType, val1, val2,
52 constant);
53 }
54 return LLVM::InsertValueOp::create(rewriter, loc, val1, val2, pos);
55}
56
57// Helper that picks the proper sequence for extracting.
58static Value extractOne(ConversionPatternRewriter &rewriter,
59 const LLVMTypeConverter &typeConverter, Location loc,
60 Value val, Type llvmType, int64_t rank, int64_t pos) {
61 if (rank <= 1) {
62 auto idxType = rewriter.getIndexType();
63 auto constant = LLVM::ConstantOp::create(
64 rewriter, loc, typeConverter.convertType(idxType),
65 rewriter.getIntegerAttr(idxType, pos));
66 return LLVM::ExtractElementOp::create(rewriter, loc, llvmType, val,
67 constant);
68 }
69 return LLVM::ExtractValueOp::create(rewriter, loc, val, pos);
70}
71
72// Helper that returns the data layout alignment of a vector type.
73LogicalResult getVectorAlignment(const LLVMTypeConverter &typeConverter,
74 VectorType vectorType, unsigned &align) {
75 Type convertedVectorTy = typeConverter.convertType(vectorType);
76 if (!convertedVectorTy)
77 return failure();
78
79 llvm::LLVMContext llvmContext;
80 align = LLVM::TypeToLLVMIRTranslator(llvmContext)
81 .getPreferredAlignment(convertedVectorTy,
82 typeConverter.getDataLayout());
83
84 return success();
85}
86
87// Helper that returns the data layout alignment of a memref's element type.
88LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter,
89 MemRefType memrefType, unsigned &align) {
90 Type elementTy = typeConverter.convertType(memrefType.getElementType());
91 if (!elementTy)
92 return failure();
93
94 // TODO: this should use the MLIR data layout when it becomes available and
95 // stop depending on translation.
96 llvm::LLVMContext llvmContext;
97 align = LLVM::TypeToLLVMIRTranslator(llvmContext)
98 .getPreferredAlignment(elementTy, typeConverter.getDataLayout());
99 return success();
100}
101
102// Helper to resolve the alignment for vector load/store, gather and scatter
103// ops. If useVectorAlignment is true, get the preferred alignment for the
104// vector type in the operation. This option is used for hardware backends with
105// vectorization. Otherwise, use the preferred alignment of the element type of
106// the memref. Note that if you choose to use vector alignment, the shape of the
107// vector type must be resolved before the ConvertVectorToLLVM pass is run.
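// For illustration (not part of the original comment): with a typical 64-bit
// data layout, a vector<4xf32> access resolves to align 16 when vector
// alignment is requested, and to align 4 (the f32 element alignment)
// otherwise. The exact values depend on the module's data layout.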
108LogicalResult getVectorToLLVMAlignment(const LLVMTypeConverter &typeConverter,
109 VectorType vectorType,
110 MemRefType memrefType, unsigned &align,
111 bool useVectorAlignment) {
112 if (useVectorAlignment) {
113 if (failed(getVectorAlignment(typeConverter, vectorType, align))) {
114 return failure();
115 }
116 } else {
117 if (failed(getMemRefAlignment(typeConverter, memrefType, align))) {
118 return failure();
119 }
120 }
121 return success();
122}
123
124// Check that the memref has a unit last stride and a convertible memory space.
125static LogicalResult isMemRefTypeSupported(MemRefType memRefType,
126 const LLVMTypeConverter &converter) {
127 if (!memRefType.isLastDimUnitStride())
128 return failure();
129 if (failed(converter.getMemRefAddressSpace(memRefType)))
130 return failure();
131 return success();
132}
133
134// Add an index vector component to a base pointer.
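// Illustrative sketch (not from the original source): for a vector<16xi32> of
// indices into an f32 memref this roughly produces
//   %ptrs = llvm.getelementptr %base[%index]
//       : (!llvm.ptr, vector<16xi32>) -> vector<16x!llvm.ptr>, f32
// i.e. one pointer per lane, which the masked gather/scatter intrinsics
// consume.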
135static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc,
136 const LLVMTypeConverter &typeConverter,
137 MemRefType memRefType, Value llvmMemref, Value base,
138 Value index, VectorType vectorType) {
139 assert(succeeded(isMemRefTypeSupported(memRefType, typeConverter)) &&
140 "unsupported memref type");
141 assert(vectorType.getRank() == 1 && "expected a 1-d vector type");
142 auto pType = MemRefDescriptor(llvmMemref).getElementPtrType();
143 auto ptrsType =
144 LLVM::getVectorType(pType, vectorType.getDimSize(0),
145 /*isScalable=*/vectorType.getScalableDims()[0]);
146 return LLVM::GEPOp::create(
147 rewriter, loc, ptrsType,
148 typeConverter.convertType(memRefType.getElementType()), base, index);
149}
150
151/// Convert `foldResult` into a Value. Integer attribute is converted to
152/// an LLVM constant op.
153static Value getAsLLVMValue(OpBuilder &builder, Location loc,
154 OpFoldResult foldResult) {
155 if (auto attr = dyn_cast<Attribute>(foldResult)) {
156 auto intAttr = cast<IntegerAttr>(attr);
157 return LLVM::ConstantOp::create(builder, loc, intAttr).getResult();
158 }
159
160 return cast<Value>(foldResult);
161}
162
163namespace {
164
165/// Trivial Vector to LLVM conversions
166using VectorScaleOpConversion =
167 VectorConvertToLLVMPattern<vector::VectorScaleOp, LLVM::vscale>;
168
169/// Conversion pattern for a vector.bitcast.
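///
/// Example (illustrative):
/// ```
///   %0 = vector.bitcast %src : vector<4xf32> to vector<2xf64>
/// ```
/// is converted to:
/// ```
///   %0 = llvm.bitcast %src : vector<4xf32> to vector<2xf64>
/// ```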
170class VectorBitCastOpConversion
171 : public ConvertOpToLLVMPattern<vector::BitCastOp> {
172public:
173 using ConvertOpToLLVMPattern<vector::BitCastOp>::ConvertOpToLLVMPattern;
174
175 LogicalResult
176 matchAndRewrite(vector::BitCastOp bitCastOp, OpAdaptor adaptor,
177 ConversionPatternRewriter &rewriter) const override {
178 // Only 0-D and 1-D vectors can be lowered to LLVM.
179 VectorType resultTy = bitCastOp.getResultVectorType();
180 if (resultTy.getRank() > 1)
181 return failure();
182 Type newResultTy = typeConverter->convertType(resultTy);
183 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(bitCastOp, newResultTy,
184 adaptor.getOperands()[0]);
185 return success();
186 }
187};
188
189/// Overloaded utility that replaces a vector.load, vector.store,
190/// vector.maskedload and vector.maskedstore with their respective LLVM
191/// counterparts.
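///
/// For example (illustrative), a 1-D load such as
/// ```
///   %0 = vector.load %base[%i] : memref<100xf32>, vector<8xf32>
/// ```
/// becomes an `llvm.load` of the converted vector type from the strided
/// element pointer, with the alignment resolved by `getVectorToLLVMAlignment`.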
192static void replaceLoadOrStoreOp(vector::LoadOp loadOp,
193 vector::LoadOpAdaptor adaptor,
194 VectorType vectorTy, Value ptr, unsigned align,
195 ConversionPatternRewriter &rewriter) {
196 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, vectorTy, ptr, align,
197 /*volatile_=*/false,
198 loadOp.getNontemporal());
199}
200
201static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp,
202 vector::MaskedLoadOpAdaptor adaptor,
203 VectorType vectorTy, Value ptr, unsigned align,
204 ConversionPatternRewriter &rewriter) {
205 rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
206 loadOp, vectorTy, ptr, adaptor.getMask(), adaptor.getPassThru(), align);
207}
208
209static void replaceLoadOrStoreOp(vector::StoreOp storeOp,
210 vector::StoreOpAdaptor adaptor,
211 VectorType vectorTy, Value ptr, unsigned align,
212 ConversionPatternRewriter &rewriter) {
213 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValueToStore(),
214 ptr, align, /*volatile_=*/false,
215 storeOp.getNontemporal());
216}
217
218static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
219 vector::MaskedStoreOpAdaptor adaptor,
220 VectorType vectorTy, Value ptr, unsigned align,
221 ConversionPatternRewriter &rewriter) {
222 rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
223 storeOp, adaptor.getValueToStore(), ptr, adaptor.getMask(), align);
224}
225
226/// Conversion pattern for a vector.load, vector.store, vector.maskedload, and
227/// vector.maskedstore.
228template <class LoadOrStoreOp>
229class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
230public:
231 explicit VectorLoadStoreConversion(const LLVMTypeConverter &typeConv,
232 bool useVectorAlign)
233 : ConvertOpToLLVMPattern<LoadOrStoreOp>(typeConv),
234 useVectorAlignment(useVectorAlign) {}
235 using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;
236
237 LogicalResult
238 matchAndRewrite(LoadOrStoreOp loadOrStoreOp,
239 typename LoadOrStoreOp::Adaptor adaptor,
240 ConversionPatternRewriter &rewriter) const override {
241 // Only 1-D vectors can be lowered to LLVM.
242 VectorType vectorTy = loadOrStoreOp.getVectorType();
243 if (vectorTy.getRank() > 1)
244 return failure();
245
246 auto loc = loadOrStoreOp->getLoc();
247 MemRefType memRefTy = loadOrStoreOp.getMemRefType();
248
249 // Resolve alignment.
250 // Explicit alignment takes priority over use-vector-alignment.
251 unsigned align = loadOrStoreOp.getAlignment().value_or(0);
252 if (!align &&
253 failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vectorTy,
254 memRefTy, align, useVectorAlignment)))
255 return rewriter.notifyMatchFailure(loadOrStoreOp,
256 "could not resolve alignment");
257
258 // Resolve address.
259 auto vtype = cast<VectorType>(
260 this->typeConverter->convertType(loadOrStoreOp.getVectorType()));
261 Value dataPtr = this->getStridedElementPtr(
262 rewriter, loc, memRefTy, adaptor.getBase(), adaptor.getIndices());
263 replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, dataPtr, align,
264 rewriter);
265 return success();
266 }
267
268private:
269 // If true, use the preferred alignment of the vector type.
270 // If false, use the preferred alignment of the element type
271 // of the memref. This flag is intended for use with hardware
272 // backends that require alignment of vector operations.
273 const bool useVectorAlignment;
274};
275
276/// Conversion pattern for a vector.gather.
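///
/// Sketch (illustrative): the base and offsets are first turned into a scalar
/// element pointer, the index vector is expanded into a vector of per-lane
/// pointers via `getIndexedPtrs`, and the op is replaced with
/// `llvm.intr.masked.gather` taking those pointers, the mask and the
/// pass-through vector.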
277class VectorGatherOpConversion
278 : public ConvertOpToLLVMPattern<vector::GatherOp> {
279public:
280 explicit VectorGatherOpConversion(const LLVMTypeConverter &typeConv,
281 bool useVectorAlign)
282 : ConvertOpToLLVMPattern<vector::GatherOp>(typeConv),
283 useVectorAlignment(useVectorAlign) {}
284 using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;
285
286 LogicalResult
287 matchAndRewrite(vector::GatherOp gather, OpAdaptor adaptor,
288 ConversionPatternRewriter &rewriter) const override {
289 Location loc = gather->getLoc();
290 MemRefType memRefType = dyn_cast<MemRefType>(gather.getBaseType());
291 assert(memRefType && "The base should be bufferized");
292
293 if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
294 return rewriter.notifyMatchFailure(gather, "memref type not supported");
295
296 VectorType vType = gather.getVectorType();
297 if (vType.getRank() > 1) {
298 return rewriter.notifyMatchFailure(
299 gather, "only 1-D vectors can be lowered to LLVM");
300 }
301
302 // Resolve alignment.
303 // Explicit alignment takes priority over use-vector-alignment.
304 unsigned align = gather.getAlignment().value_or(0);
305 if (!align &&
306 failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
307 memRefType, align, useVectorAlignment)))
308 return rewriter.notifyMatchFailure(gather, "could not resolve alignment");
309
310 // Resolve address.
311 Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
312 adaptor.getBase(), adaptor.getOffsets());
313 Value base = adaptor.getBase();
314 Value ptrs =
315 getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
316 base, ptr, adaptor.getIndices(), vType);
317
318 // Replace with the gather intrinsic.
319 rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
320 gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
321 adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
322 return success();
323 }
324
325private:
326 // If true, use the preferred alignment of the vector type.
327 // If false, use the preferred alignment of the element type
328 // of the memref. This flag is intended for use with hardware
329 // backends that require alignment of vector operations.
330 const bool useVectorAlignment;
331};
332
333/// Conversion pattern for a vector.scatter.
334class VectorScatterOpConversion
335 : public ConvertOpToLLVMPattern<vector::ScatterOp> {
336public:
337 explicit VectorScatterOpConversion(const LLVMTypeConverter &typeConv,
338 bool useVectorAlign)
339 : ConvertOpToLLVMPattern<vector::ScatterOp>(typeConv),
340 useVectorAlignment(useVectorAlign) {}
341
342 using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;
343
344 LogicalResult
345 matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
346 ConversionPatternRewriter &rewriter) const override {
347 auto loc = scatter->getLoc();
348 auto memRefType = dyn_cast<MemRefType>(scatter.getBaseType());
349 assert(memRefType && "The base should be bufferized");
350
351 if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
352 return rewriter.notifyMatchFailure(scatter, "memref type not supported");
353
354 VectorType vType = scatter.getVectorType();
355 if (vType.getRank() > 1) {
356 return rewriter.notifyMatchFailure(
357 scatter, "only 1-D vectors can be lowered to LLVM");
358 }
359
360 // Resolve alignment.
361 // Explicit alignment takes priority over use-vector-alignment.
362 unsigned align = scatter.getAlignment().value_or(0);
363 if (!align &&
364 failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
365 memRefType, align, useVectorAlignment)))
366 return rewriter.notifyMatchFailure(scatter,
367 "could not resolve alignment");
368
369 // Resolve address.
370 Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
371 adaptor.getBase(), adaptor.getOffsets());
372 Value ptrs =
373 getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
374 adaptor.getBase(), ptr, adaptor.getIndices(), vType);
375
376 // Replace with the scatter intrinsic.
377 rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
378 scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(),
379 rewriter.getI32IntegerAttr(align));
380 return success();
381 }
382
383private:
384 // If true, use the preferred alignment of the vector type.
385 // If false, use the preferred alignment of the element type
386 // of the memref. This flag is intended for use with hardware
387 // backends that require alignment of vector operations.
388 const bool useVectorAlignment;
389};
390
391/// Conversion pattern for a vector.expandload.
392class VectorExpandLoadOpConversion
393 : public ConvertOpToLLVMPattern<vector::ExpandLoadOp> {
394public:
395 using ConvertOpToLLVMPattern<vector::ExpandLoadOp>::ConvertOpToLLVMPattern;
396
397 LogicalResult
398 matchAndRewrite(vector::ExpandLoadOp expand, OpAdaptor adaptor,
399 ConversionPatternRewriter &rewriter) const override {
400 auto loc = expand->getLoc();
401 MemRefType memRefType = expand.getMemRefType();
402
403 // Resolve address.
404 auto vtype = typeConverter->convertType(expand.getVectorType());
405 Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
406 adaptor.getBase(), adaptor.getIndices());
407
408 // From:
409 // https://llvm.org/docs/LangRef.html#llvm-masked-expandload-intrinsics
410 // The pointer alignment defaults to 1.
411 uint64_t alignment = expand.getAlignment().value_or(1);
412
413 rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
414 expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru(),
415 alignment);
416 return success();
417 }
418};
419
420/// Conversion pattern for a vector.compressstore.
421class VectorCompressStoreOpConversion
422 : public ConvertOpToLLVMPattern<vector::CompressStoreOp> {
423public:
424 using ConvertOpToLLVMPattern<vector::CompressStoreOp>::ConvertOpToLLVMPattern;
425
426 LogicalResult
427 matchAndRewrite(vector::CompressStoreOp compress, OpAdaptor adaptor,
428 ConversionPatternRewriter &rewriter) const override {
429 auto loc = compress->getLoc();
430 MemRefType memRefType = compress.getMemRefType();
431
432 // Resolve address.
433 Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
434 adaptor.getBase(), adaptor.getIndices());
435
436 // From:
437 // https://llvm.org/docs/LangRef.html#llvm-masked-compressstore-intrinsics
438 // The pointer alignment defaults to 1.
439 uint64_t alignment = compress.getAlignment().value_or(1);
440
441 rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
442 compress, adaptor.getValueToStore(), ptr, adaptor.getMask(), alignment);
443 return success();
444 }
445};
446
447/// Reduction neutral classes for overloading.
448class ReductionNeutralZero {};
449class ReductionNeutralIntOne {};
450class ReductionNeutralFPOne {};
451class ReductionNeutralAllOnes {};
452class ReductionNeutralSIntMin {};
453class ReductionNeutralUIntMin {};
454class ReductionNeutralSIntMax {};
455class ReductionNeutralUIntMax {};
456class ReductionNeutralFPMin {};
457class ReductionNeutralFPMax {};
458
459/// Create the reduction neutral zero value.
460static Value createReductionNeutralValue(ReductionNeutralZero neutral,
461 ConversionPatternRewriter &rewriter,
462 Location loc, Type llvmType) {
463 return LLVM::ConstantOp::create(rewriter, loc, llvmType,
464 rewriter.getZeroAttr(llvmType));
465}
466
467/// Create the reduction neutral integer one value.
468static Value createReductionNeutralValue(ReductionNeutralIntOne neutral,
469 ConversionPatternRewriter &rewriter,
470 Location loc, Type llvmType) {
471 return LLVM::ConstantOp::create(rewriter, loc, llvmType,
472 rewriter.getIntegerAttr(llvmType, 1));
473}
474
475/// Create the reduction neutral fp one value.
476static Value createReductionNeutralValue(ReductionNeutralFPOne neutral,
477 ConversionPatternRewriter &rewriter,
478 Location loc, Type llvmType) {
479 return LLVM::ConstantOp::create(rewriter, loc, llvmType,
480 rewriter.getFloatAttr(llvmType, 1.0));
481}
482
483/// Create the reduction neutral all-ones value.
484static Value createReductionNeutralValue(ReductionNeutralAllOnes neutral,
485 ConversionPatternRewriter &rewriter,
486 Location loc, Type llvmType) {
487 return LLVM::ConstantOp::create(
488 rewriter, loc, llvmType,
489 rewriter.getIntegerAttr(
490 llvmType, llvm::APInt::getAllOnes(llvmType.getIntOrFloatBitWidth())));
491}
492
493/// Create the reduction neutral signed int minimum value.
494static Value createReductionNeutralValue(ReductionNeutralSIntMin neutral,
495 ConversionPatternRewriter &rewriter,
496 Location loc, Type llvmType) {
497 return LLVM::ConstantOp::create(
498 rewriter, loc, llvmType,
499 rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMinValue(
500 llvmType.getIntOrFloatBitWidth())));
501}
502
503/// Create the reduction neutral unsigned int minimum value.
504static Value createReductionNeutralValue(ReductionNeutralUIntMin neutral,
505 ConversionPatternRewriter &rewriter,
506 Location loc, Type llvmType) {
507 return LLVM::ConstantOp::create(
508 rewriter, loc, llvmType,
509 rewriter.getIntegerAttr(llvmType, llvm::APInt::getMinValue(
510 llvmType.getIntOrFloatBitWidth())));
511}
512
513/// Create the reduction neutral signed int maximum value.
514static Value createReductionNeutralValue(ReductionNeutralSIntMax neutral,
515 ConversionPatternRewriter &rewriter,
516 Location loc, Type llvmType) {
517 return LLVM::ConstantOp::create(
518 rewriter, loc, llvmType,
519 rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMaxValue(
520 llvmType.getIntOrFloatBitWidth())));
521}
522
523/// Create the reduction neutral unsigned int maximum value.
524static Value createReductionNeutralValue(ReductionNeutralUIntMax neutral,
525 ConversionPatternRewriter &rewriter,
526 Location loc, Type llvmType) {
527 return LLVM::ConstantOp::create(
528 rewriter, loc, llvmType,
529 rewriter.getIntegerAttr(llvmType, llvm::APInt::getMaxValue(
530 llvmType.getIntOrFloatBitWidth())));
531}
532
533/// Create the reduction neutral fp minimum value.
534static Value createReductionNeutralValue(ReductionNeutralFPMin neutral,
535 ConversionPatternRewriter &rewriter,
536 Location loc, Type llvmType) {
537 auto floatType = cast<FloatType>(llvmType);
538 return LLVM::ConstantOp::create(
539 rewriter, loc, llvmType,
540 rewriter.getFloatAttr(
541 llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
542 /*Negative=*/false)));
543}
544
545/// Create the reduction neutral fp maximum value.
546static Value createReductionNeutralValue(ReductionNeutralFPMax neutral,
547 ConversionPatternRewriter &rewriter,
548 Location loc, Type llvmType) {
549 auto floatType = cast<FloatType>(llvmType);
550 return LLVM::ConstantOp::create(
551 rewriter, loc, llvmType,
552 rewriter.getFloatAttr(
553 llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
554 /*Negative=*/true)));
555}
556
557/// Returns `accumulator` if it has a valid value. Otherwise, creates and
558/// returns a new accumulator value using `ReductionNeutral`.
559template <class ReductionNeutral>
560static Value getOrCreateAccumulator(ConversionPatternRewriter &rewriter,
561 Location loc, Type llvmType,
562 Value accumulator) {
563 if (accumulator)
564 return accumulator;
565
566 return createReductionNeutralValue(ReductionNeutral(), rewriter, loc,
567 llvmType);
568}
569
570/// Creates a value holding the vector length of the 1-D vector type provided
571/// in `llvmType`. This is used as the effective vector length by some
572/// intrinsics that support dynamic vector lengths at runtime.
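/// For example (illustrative), for `vector<8xf32>` this is simply the i32
/// constant 8, while for a scalable `vector<[8]xf32>` it is `8 * vector.vscale`
/// with the vscale value index-cast to i32.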
573static Value createVectorLengthValue(ConversionPatternRewriter &rewriter,
574 Location loc, Type llvmType) {
575 VectorType vType = cast<VectorType>(llvmType);
576 auto vShape = vType.getShape();
577 assert(vShape.size() == 1 && "Unexpected multi-dim vector type");
578
579 Value baseVecLength = LLVM::ConstantOp::create(
580 rewriter, loc, rewriter.getI32Type(),
581 rewriter.getIntegerAttr(rewriter.getI32Type(), vShape[0]));
582
583 if (!vType.getScalableDims()[0])
584 return baseVecLength;
585
586 // For a scalable vector type, create and return `vScale * baseVecLength`.
587 Value vScale = vector::VectorScaleOp::create(rewriter, loc);
588 vScale =
589 arith::IndexCastOp::create(rewriter, loc, rewriter.getI32Type(), vScale);
590 Value scalableVecLength =
591 arith::MulIOp::create(rewriter, loc, baseVecLength, vScale);
592 return scalableVecLength;
593}
594
595/// Helper method to lower a `vector.reduction` op that performs an arithmetic
596/// operation like add, mul, etc. `LLVMRedIntrinOp` is the LLVM vector reduction
597/// intrinsic to use and `ScalarOp` is the scalar operation used to combine the
598/// result with the accumulator value, if non-null.
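/// For example (illustrative), an integer add reduction with an accumulator
/// roughly lowers to:
///   %r = llvm.intr.vector.reduce.add(%vec) : (vector<16xi32>) -> i32
///   %s = llvm.add %acc, %r : i32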
599template <class LLVMRedIntrinOp, class ScalarOp>
600static Value createIntegerReductionArithmeticOpLowering(
601 ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
602 Value vectorOperand, Value accumulator) {
603
604 Value result =
605 LLVMRedIntrinOp::create(rewriter, loc, llvmType, vectorOperand);
606
607 if (accumulator)
608 result = ScalarOp::create(rewriter, loc, accumulator, result);
609 return result;
610}
611
612/// Helper method to lower a `vector.reduction` operation that performs
613/// a comparison operation like `min`/`max`. `LLVMRedIntrinOp` is the LLVM
614/// vector reduction intrinsic to use and `predicate` is used to compare and
615/// combine the result with the accumulator value, if non-null.
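/// For example (illustrative), a signed-min reduction with an accumulator
/// roughly lowers to:
///   %r = llvm.intr.vector.reduce.smin(%vec) : (vector<16xi32>) -> i32
///   %c = llvm.icmp "sle" %acc, %r : i32
///   %s = llvm.select %c, %acc, %r : i1, i32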
616template <class LLVMRedIntrinOp>
617static Value createIntegerReductionComparisonOpLowering(
618 ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
619 Value vectorOperand, Value accumulator, LLVM::ICmpPredicate predicate) {
620 Value result =
621 LLVMRedIntrinOp::create(rewriter, loc, llvmType, vectorOperand);
622 if (accumulator) {
623 Value cmp =
624 LLVM::ICmpOp::create(rewriter, loc, predicate, accumulator, result);
625 result = LLVM::SelectOp::create(rewriter, loc, cmp, accumulator, result);
626 }
627 return result;
628}
629
630namespace {
631template <typename Source>
632struct VectorToScalarMapper;
633template <>
634struct VectorToScalarMapper<LLVM::vector_reduce_fmaximum> {
635 using Type = LLVM::MaximumOp;
636};
637template <>
638struct VectorToScalarMapper<LLVM::vector_reduce_fminimum> {
639 using Type = LLVM::MinimumOp;
640};
641template <>
642struct VectorToScalarMapper<LLVM::vector_reduce_fmax> {
643 using Type = LLVM::MaxNumOp;
644};
645template <>
646struct VectorToScalarMapper<LLVM::vector_reduce_fmin> {
647 using Type = LLVM::MinNumOp;
648};
649} // namespace
650
651template <class LLVMRedIntrinOp>
652static Value createFPReductionComparisonOpLowering(
653 ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
654 Value vectorOperand, Value accumulator, LLVM::FastmathFlagsAttr fmf) {
655 Value result =
656 LLVMRedIntrinOp::create(rewriter, loc, llvmType, vectorOperand, fmf);
657
658 if (accumulator) {
659 result = VectorToScalarMapper<LLVMRedIntrinOp>::Type::create(
660 rewriter, loc, result, accumulator);
661 }
662
663 return result;
664}
665
666/// Mask neutral classes for overloading
667class MaskNeutralFMaximum {};
668class MaskNeutralFMinimum {};
669
670/// Get the mask neutral floating point maximum value
671static llvm::APFloat
672getMaskNeutralValue(MaskNeutralFMaximum,
673 const llvm::fltSemantics &floatSemantics) {
674 return llvm::APFloat::getSmallest(floatSemantics, /*Negative=*/true);
675}
676/// Get the mask neutral floating point minimum value
677static llvm::APFloat
678getMaskNeutralValue(MaskNeutralFMinimum,
679 const llvm::fltSemantics &floatSemantics) {
680 return llvm::APFloat::getLargest(floatSemantics, /*Negative=*/false);
681}
682
683/// Create the mask neutral floating point MLIR vector constant
684template <typename MaskNeutral>
685static Value createMaskNeutralValue(ConversionPatternRewriter &rewriter,
686 Location loc, Type llvmType,
687 Type vectorType) {
688 const auto &floatSemantics = cast<FloatType>(llvmType).getFloatSemantics();
689 auto value = getMaskNeutralValue(MaskNeutral{}, floatSemantics);
690 auto denseValue = DenseElementsAttr::get(cast<ShapedType>(vectorType), value);
691 return LLVM::ConstantOp::create(rewriter, loc, vectorType, denseValue);
692}
693
694/// Lowers masked `fmaximum` and `fminimum` reductions using the non-masked
695/// intrinsics. It is a workaround to overcome the lack of masked intrinsics for
696/// `fmaximum`/`fminimum`.
697/// More information: https://github.com/llvm/llvm-project/issues/64940
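/// Sketch (illustrative): the masked-off lanes are first replaced with the
/// mask-neutral constant defined above, e.g.
///   %sel = llvm.select %mask, %vec, %neutral_splat
/// and the result is then reduced with the regular (non-masked)
/// `llvm.intr.vector.reduce.fmaximum`/`fminimum` intrinsic.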
698template <class LLVMRedIntrinOp, class MaskNeutral>
699static Value
700lowerMaskedReductionWithRegular(ConversionPatternRewriter &rewriter,
701 Location loc, Type llvmType,
702 Value vectorOperand, Value accumulator,
703 Value mask, LLVM::FastmathFlagsAttr fmf) {
704 const Value vectorMaskNeutral = createMaskNeutralValue<MaskNeutral>(
705 rewriter, loc, llvmType, vectorOperand.getType());
706 const Value selectedVectorByMask = LLVM::SelectOp::create(
707 rewriter, loc, mask, vectorOperand, vectorMaskNeutral);
708 return createFPReductionComparisonOpLowering<LLVMRedIntrinOp>(
709 rewriter, loc, llvmType, selectedVectorByMask, accumulator, fmf);
710}
711
712template <class LLVMRedIntrinOp, class ReductionNeutral>
713static Value
714lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, Location loc,
715 Type llvmType, Value vectorOperand,
716 Value accumulator, LLVM::FastmathFlagsAttr fmf) {
717 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
718 llvmType, accumulator);
719 return LLVMRedIntrinOp::create(rewriter, loc, llvmType,
720 /*start_value=*/accumulator, vectorOperand,
721 fmf);
722}
723
724/// Overloaded methods to lower a *predicated* reduction to an LLVM intrinsic
725/// that requires a start value. This start-value form is shared by the unmasked
726/// fp reductions and by all of the masked reduction intrinsics.
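/// For example (illustrative), a masked add reduction roughly becomes
///   %r = llvm.intr.vp.reduce.add(%start, %vec, %mask, %evl)
/// where %evl is the effective vector length computed by
/// `createVectorLengthValue`.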
727template <class LLVMVPRedIntrinOp, class ReductionNeutral>
728static Value
729lowerPredicatedReductionWithStartValue(ConversionPatternRewriter &rewriter,
730 Location loc, Type llvmType,
731 Value vectorOperand, Value accumulator) {
732 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
733 llvmType, accumulator);
734 return LLVMVPRedIntrinOp::create(rewriter, loc, llvmType,
735 /*startValue=*/accumulator, vectorOperand);
736}
737
738template <class LLVMVPRedIntrinOp, class ReductionNeutral>
739static Value lowerPredicatedReductionWithStartValue(
740 ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
741 Value vectorOperand, Value accumulator, Value mask) {
742 accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
743 llvmType, accumulator);
744 Value vectorLength =
745 createVectorLengthValue(rewriter, loc, vectorOperand.getType());
746 return LLVMVPRedIntrinOp::create(rewriter, loc, llvmType,
747 /*start_value=*/accumulator, vectorOperand,
748 mask, vectorLength);
749}
750
751template <class LLVMIntVPRedIntrinOp, class IntReductionNeutral,
752 class LLVMFPVPRedIntrinOp, class FPReductionNeutral>
753static Value lowerPredicatedReductionWithStartValue(
754 ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
755 Value vectorOperand, Value accumulator, Value mask) {
756 if (llvmType.isIntOrIndex())
757 return lowerPredicatedReductionWithStartValue<LLVMIntVPRedIntrinOp,
758 IntReductionNeutral>(
759 rewriter, loc, llvmType, vectorOperand, accumulator, mask);
760
761 // FP dispatch.
762 return lowerPredicatedReductionWithStartValue<LLVMFPVPRedIntrinOp,
763 FPReductionNeutral>(
764 rewriter, loc, llvmType, vectorOperand, accumulator, mask);
765}
766
767/// Conversion pattern for all vector reductions.
768class VectorReductionOpConversion
769 : public ConvertOpToLLVMPattern<vector::ReductionOp> {
770public:
771 explicit VectorReductionOpConversion(const LLVMTypeConverter &typeConv,
772 bool reassociateFPRed)
773 : ConvertOpToLLVMPattern<vector::ReductionOp>(typeConv),
774 reassociateFPReductions(reassociateFPRed) {}
775
776 LogicalResult
777 matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor,
778 ConversionPatternRewriter &rewriter) const override {
779 auto kind = reductionOp.getKind();
780 Type eltType = reductionOp.getDest().getType();
781 Type llvmType = typeConverter->convertType(eltType);
782 Value operand = adaptor.getVector();
783 Value acc = adaptor.getAcc();
784 Location loc = reductionOp.getLoc();
785
786 if (eltType.isIntOrIndex()) {
787 // Integer reductions: add/mul/min/max/and/or/xor.
788 Value result;
789 switch (kind) {
790 case vector::CombiningKind::ADD:
791 result =
792 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_add,
793 LLVM::AddOp>(
794 rewriter, loc, llvmType, operand, acc);
795 break;
796 case vector::CombiningKind::MUL:
797 result =
798 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_mul,
799 LLVM::MulOp>(
800 rewriter, loc, llvmType, operand, acc);
801 break;
802 case vector::CombiningKind::MINUI:
803 result = createIntegerReductionComparisonOpLowering<
804 LLVM::vector_reduce_umin>(rewriter, loc, llvmType, operand, acc,
805 LLVM::ICmpPredicate::ule);
806 break;
807 case vector::CombiningKind::MINSI:
808 result = createIntegerReductionComparisonOpLowering<
809 LLVM::vector_reduce_smin>(rewriter, loc, llvmType, operand, acc,
810 LLVM::ICmpPredicate::sle);
811 break;
812 case vector::CombiningKind::MAXUI:
813 result = createIntegerReductionComparisonOpLowering<
814 LLVM::vector_reduce_umax>(rewriter, loc, llvmType, operand, acc,
815 LLVM::ICmpPredicate::uge);
816 break;
817 case vector::CombiningKind::MAXSI:
818 result = createIntegerReductionComparisonOpLowering<
819 LLVM::vector_reduce_smax>(rewriter, loc, llvmType, operand, acc,
820 LLVM::ICmpPredicate::sge);
821 break;
822 case vector::CombiningKind::AND:
823 result =
824 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_and,
825 LLVM::AndOp>(
826 rewriter, loc, llvmType, operand, acc);
827 break;
828 case vector::CombiningKind::OR:
829 result =
830 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_or,
831 LLVM::OrOp>(
832 rewriter, loc, llvmType, operand, acc);
833 break;
834 case vector::CombiningKind::XOR:
835 result =
836 createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_xor,
837 LLVM::XOrOp>(
838 rewriter, loc, llvmType, operand, acc);
839 break;
840 default:
841 return failure();
842 }
843 rewriter.replaceOp(reductionOp, result);
844
845 return success();
846 }
847
848 if (!isa<FloatType>(eltType))
849 return failure();
850
851 arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
852 LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
853 reductionOp.getContext(),
854 convertArithFastMathFlagsToLLVM(fMFAttr.getValue()));
855 fmf = LLVM::FastmathFlagsAttr::get(
856 reductionOp.getContext(),
857 fmf.getValue() | (reassociateFPReductions ? LLVM::FastmathFlags::reassoc
858 : LLVM::FastmathFlags::none));
859
860 // Floating-point reductions: add/mul/min/max
861 Value result;
862 if (kind == vector::CombiningKind::ADD) {
863 result = lowerReductionWithStartValue<LLVM::vector_reduce_fadd,
864 ReductionNeutralZero>(
865 rewriter, loc, llvmType, operand, acc, fmf);
866 } else if (kind == vector::CombiningKind::MUL) {
867 result = lowerReductionWithStartValue<LLVM::vector_reduce_fmul,
868 ReductionNeutralFPOne>(
869 rewriter, loc, llvmType, operand, acc, fmf);
870 } else if (kind == vector::CombiningKind::MINIMUMF) {
871 result =
872 createFPReductionComparisonOpLowering<LLVM::vector_reduce_fminimum>(
873 rewriter, loc, llvmType, operand, acc, fmf);
874 } else if (kind == vector::CombiningKind::MAXIMUMF) {
875 result =
876 createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmaximum>(
877 rewriter, loc, llvmType, operand, acc, fmf);
878 } else if (kind == vector::CombiningKind::MINNUMF) {
879 result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmin>(
880 rewriter, loc, llvmType, operand, acc, fmf);
881 } else if (kind == vector::CombiningKind::MAXNUMF) {
882 result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmax>(
883 rewriter, loc, llvmType, operand, acc, fmf);
884 } else {
885 return failure();
886 }
887
888 rewriter.replaceOp(reductionOp, result);
889 return success();
890 }
891
892private:
893 const bool reassociateFPReductions;
894};
895
896/// Base class to convert a `vector.mask` operation while matching traits
897/// of the maskable operation nested inside. A `VectorMaskOpConversionBase`
898/// instance matches against a `vector.mask` operation. The `matchAndRewrite`
899/// method performs a second match against the maskable operation `MaskedOp`.
900/// Finally, it invokes the virtual method `matchAndRewriteMaskableOp` to be
901/// implemented by the concrete conversion classes. This method can match
902/// against specific traits of the `vector.mask` and the maskable operation. It
903/// must replace the `vector.mask` operation.
904template <class MaskedOp>
905class VectorMaskOpConversionBase
906 : public ConvertOpToLLVMPattern<vector::MaskOp> {
907public:
908 using ConvertOpToLLVMPattern<vector::MaskOp>::ConvertOpToLLVMPattern;
909
910 LogicalResult
911 matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
912 ConversionPatternRewriter &rewriter) const final {
913 // Match against the maskable operation kind.
914 auto maskedOp = llvm::dyn_cast_or_null<MaskedOp>(maskOp.getMaskableOp());
915 if (!maskedOp)
916 return failure();
917 return matchAndRewriteMaskableOp(maskOp, maskedOp, rewriter);
918 }
919
920protected:
921 virtual LogicalResult
922 matchAndRewriteMaskableOp(vector::MaskOp maskOp,
923 vector::MaskableOpInterface maskableOp,
924 ConversionPatternRewriter &rewriter) const = 0;
925};
926
927class MaskedReductionOpConversion
928 : public VectorMaskOpConversionBase<vector::ReductionOp> {
929
930public:
931 using VectorMaskOpConversionBase<
932 vector::ReductionOp>::VectorMaskOpConversionBase;
933
934 LogicalResult matchAndRewriteMaskableOp(
935 vector::MaskOp maskOp, MaskableOpInterface maskableOp,
936 ConversionPatternRewriter &rewriter) const override {
937 auto reductionOp = cast<ReductionOp>(maskableOp.getOperation());
938 auto kind = reductionOp.getKind();
939 Type eltType = reductionOp.getDest().getType();
940 Type llvmType = typeConverter->convertType(eltType);
941 Value operand = reductionOp.getVector();
942 Value acc = reductionOp.getAcc();
943 Location loc = reductionOp.getLoc();
944
945 arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr();
946 LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
947 reductionOp.getContext(),
948 convertArithFastMathFlagsToLLVM(fMFAttr.getValue()));
949
950 Value result;
951 switch (kind) {
952 case vector::CombiningKind::ADD:
953 result = lowerPredicatedReductionWithStartValue<
954 LLVM::VPReduceAddOp, ReductionNeutralZero, LLVM::VPReduceFAddOp,
955 ReductionNeutralZero>(rewriter, loc, llvmType, operand, acc,
956 maskOp.getMask());
957 break;
958 case vector::CombiningKind::MUL:
959 result = lowerPredicatedReductionWithStartValue<
960 LLVM::VPReduceMulOp, ReductionNeutralIntOne, LLVM::VPReduceFMulOp,
961 ReductionNeutralFPOne>(rewriter, loc, llvmType, operand, acc,
962 maskOp.getMask());
963 break;
964 case vector::CombiningKind::MINUI:
965 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMinOp,
966 ReductionNeutralUIntMax>(
967 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
968 break;
969 case vector::CombiningKind::MINSI:
970 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMinOp,
971 ReductionNeutralSIntMax>(
972 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
973 break;
974 case vector::CombiningKind::MAXUI:
975 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceUMaxOp,
976 ReductionNeutralUIntMin>(
977 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
978 break;
979 case vector::CombiningKind::MAXSI:
980 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceSMaxOp,
981 ReductionNeutralSIntMin>(
982 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
983 break;
984 case vector::CombiningKind::AND:
985 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceAndOp,
986 ReductionNeutralAllOnes>(
987 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
988 break;
989 case vector::CombiningKind::OR:
990 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceOrOp,
991 ReductionNeutralZero>(
992 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
993 break;
994 case vector::CombiningKind::XOR:
995 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceXorOp,
996 ReductionNeutralZero>(
997 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
998 break;
999 case vector::CombiningKind::MINNUMF:
1000 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMinOp,
1001 ReductionNeutralFPMax>(
1002 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
1003 break;
1004 case vector::CombiningKind::MAXNUMF:
1005 result = lowerPredicatedReductionWithStartValue<LLVM::VPReduceFMaxOp,
1006 ReductionNeutralFPMin>(
1007 rewriter, loc, llvmType, operand, acc, maskOp.getMask());
1008 break;
1009 case CombiningKind::MAXIMUMF:
1010 result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fmaximum,
1011 MaskNeutralFMaximum>(
1012 rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
1013 break;
1014 case CombiningKind::MINIMUMF:
1015 result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fminimum,
1016 MaskNeutralFMinimum>(
1017 rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf);
1018 break;
1019 }
1020
1021 // Replace `vector.mask` operation altogether.
1022 rewriter.replaceOp(maskOp, result);
1023 return success();
1024 }
1025};
1026
1027class VectorShuffleOpConversion
1028 : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
1029public:
1030 using ConvertOpToLLVMPattern<vector::ShuffleOp>::ConvertOpToLLVMPattern;
1031
1032 LogicalResult
1033 matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
1034 ConversionPatternRewriter &rewriter) const override {
1035 auto loc = shuffleOp->getLoc();
1036 auto v1Type = shuffleOp.getV1VectorType();
1037 auto v2Type = shuffleOp.getV2VectorType();
1038 auto vectorType = shuffleOp.getResultVectorType();
1039 Type llvmType = typeConverter->convertType(vectorType);
1040 ArrayRef<int64_t> mask = shuffleOp.getMask();
1041
1042 // Bail if result type cannot be lowered.
1043 if (!llvmType)
1044 return failure();
1045
1046 // Get rank and dimension sizes.
1047 int64_t rank = vectorType.getRank();
1048#ifndef NDEBUG
1049 bool wellFormed0DCase =
1050 v1Type.getRank() == 0 && v2Type.getRank() == 0 && rank == 1;
1051 bool wellFormedNDCase =
1052 v1Type.getRank() == rank && v2Type.getRank() == rank;
1053 assert((wellFormed0DCase || wellFormedNDCase) && "op is not well-formed");
1054#endif
1055
1056 // For rank 0 and 1, where both operands have *exactly* the same vector
1057 // type, there is direct shuffle support in LLVM. Use it!
1058 if (rank <= 1 && v1Type == v2Type) {
1059 Value llvmShuffleOp = LLVM::ShuffleVectorOp::create(
1060 rewriter, loc, adaptor.getV1(), adaptor.getV2(),
1061 llvm::to_vector_of<int32_t>(mask));
1062 rewriter.replaceOp(shuffleOp, llvmShuffleOp);
1063 return success();
1064 }
1065
1066 // For all other cases, extract and insert the values one by one.
1067 int64_t v1Dim = v1Type.getDimSize(0);
1068 Type eltType;
1069 if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(llvmType))
1070 eltType = arrayType.getElementType();
1071 else
1072 eltType = cast<VectorType>(llvmType).getElementType();
1073 Value insert = LLVM::PoisonOp::create(rewriter, loc, llvmType);
1074 int64_t insPos = 0;
1075 for (int64_t extPos : mask) {
1076 Value value = adaptor.getV1();
1077 if (extPos >= v1Dim) {
1078 extPos -= v1Dim;
1079 value = adaptor.getV2();
1080 }
1081 Value extract = extractOne(rewriter, *getTypeConverter(), loc, value,
1082 eltType, rank, extPos);
1083 insert = insertOne(rewriter, *getTypeConverter(), loc, insert, extract,
1084 llvmType, rank, insPos++);
1085 }
1086 rewriter.replaceOp(shuffleOp, insert);
1087 return success();
1088 }
1089};
1090
1091class VectorExtractOpConversion
1092 : public ConvertOpToLLVMPattern<vector::ExtractOp> {
1093public:
1094 using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;
1095
1096 LogicalResult
1097 matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
1098 ConversionPatternRewriter &rewriter) const override {
1099 auto loc = extractOp->getLoc();
1100 auto resultType = extractOp.getResult().getType();
1101 auto llvmResultType = typeConverter->convertType(resultType);
1102 // Bail if result type cannot be lowered.
1103 if (!llvmResultType)
1104 return failure();
1105
1106 SmallVector<OpFoldResult> positionVec = getMixedValues(
1107 adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
1108
1109 // The Vector -> LLVM lowering models N-D vectors as nested aggregates of
1110 // 1-D vectors. This nesting is modeled using arrays. We do this conversion
1111 // from an N-D vector extract to a nested aggregate vector extract in two
1112 // steps:
1113 // - Extract a member from the nested aggregate. The result can be
1114 // a lower rank nested aggregate or a vector (1-D). This is done using
1115 // `llvm.extractvalue`.
1116 // - Extract a scalar out of the vector if needed. This is done using
1117 // `llvm.extractelement`.
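    // For example (illustrative), extracting a scalar from a 2-D vector:
    //   vector.extract %v[1, 2] : f32 from vector<2x4xf32>
    // becomes, roughly:
    //   %a = llvm.extractvalue %v[1] : !llvm.array<2 x vector<4xf32>>
    //   %s = llvm.extractelement %a[%c2 : i64] : vector<4xf32>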
1118
1119 // Determine if we need to extract a member out of the aggregate. We
1120 // always need to extract a member if the input rank >= 2.
1121 bool extractsAggregate = extractOp.getSourceVectorType().getRank() >= 2;
1122 // Determine if we need to extract a scalar as the result. We extract
1123 // a scalar if the extract is full rank, i.e., the number of indices is
1124 // equal to source vector rank.
1125 bool extractsScalar = static_cast<int64_t>(positionVec.size()) ==
1126 extractOp.getSourceVectorType().getRank();
1127
1128 // Since the LLVM type converter converts 0-D vectors to 1-D vectors, we
1129 // need to append a zero position to account for this change.
1130 if (extractOp.getSourceVectorType().getRank() == 0) {
1131 Type idxType = typeConverter->convertType(rewriter.getIndexType());
1132 positionVec.push_back(rewriter.getZeroAttr(idxType));
1133 }
1134
1135 Value extracted = adaptor.getSource();
1136 if (extractsAggregate) {
1137 ArrayRef<OpFoldResult> position(positionVec);
1138 if (extractsScalar) {
1139 // If we are extracting a scalar from the extracted member, we drop
1140 // the last index, which will be used to extract the scalar out of the
1141 // vector.
1142 position = position.drop_back();
1143 }
1144 // llvm.extractvalue does not support dynamic dimensions.
1145 if (!llvm::all_of(position, llvm::IsaPred<Attribute>)) {
1146 return failure();
1147 }
1148 extracted = LLVM::ExtractValueOp::create(rewriter, loc, extracted,
1149 getAsIntegers(position));
1150 }
1151
1152 if (extractsScalar) {
1153 extracted = LLVM::ExtractElementOp::create(
1154 rewriter, loc, extracted,
1155 getAsLLVMValue(rewriter, loc, positionVec.back()));
1156 }
1157
1158 rewriter.replaceOp(extractOp, extracted);
1159 return success();
1160 }
1161};
1162
1163/// Conversion pattern that turns a vector.fma on a 1-D vector
1164/// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion.
1165/// This does not match vectors of rank n >= 2.
1166///
1167/// Example:
1168/// ```
1169/// vector.fma %a, %a, %a : vector<8xf32>
1170/// ```
1171/// is converted to:
1172/// ```
1173/// llvm.intr.fmuladd %va, %va, %va :
1174/// (vector<8xf32>, vector<8xf32>, vector<8xf32>)
1175/// -> vector<8xf32>
1176/// ```
1177class VectorFMAOp1DConversion : public ConvertOpToLLVMPattern<vector::FMAOp> {
1178public:
1179 using ConvertOpToLLVMPattern<vector::FMAOp>::ConvertOpToLLVMPattern;
1180
1181 LogicalResult
1182 matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
1183 ConversionPatternRewriter &rewriter) const override {
1184 VectorType vType = fmaOp.getVectorType();
1185 if (vType.getRank() > 1)
1186 return failure();
1187
1188 rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(
1189 fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
1190 return success();
1191 }
1192};
1193
1194class VectorInsertOpConversion
1195 : public ConvertOpToLLVMPattern<vector::InsertOp> {
1196public:
1197 using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;
1198
1199 LogicalResult
1200 matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
1201 ConversionPatternRewriter &rewriter) const override {
1202 auto loc = insertOp->getLoc();
1203 auto destVectorType = insertOp.getDestVectorType();
1204 auto llvmResultType = typeConverter->convertType(destVectorType);
1205 // Bail if result type cannot be lowered.
1206 if (!llvmResultType)
1207 return failure();
1208
1209 SmallVector<OpFoldResult> positionVec = getMixedValues(
1210 adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
1211
1212 // The logic in this pattern mirrors VectorExtractOpConversion. Refer to
1213 // its explanatory comment about how N-D vectors are converted as nested
1214 // aggregates (llvm.array's) of 1D vectors.
1215 //
1216 // The innermost dimension of the destination vector, when converted to a
1217 // nested aggregate form, will always be a 1D vector.
1218 //
1219 // * If the insertion is happening into the innermost dimension of the
1220 // destination vector:
1221 // - If the destination is a nested aggregate, extract a 1D vector out of
1222 // the aggregate. This can be done using llvm.extractvalue. The
1223 // destination is now guaranteed to be a 1D vector, to which we are
1224 // inserting.
1225 // - Do the insertion into the 1D destination vector, and make the result
1226 // the new source nested aggregate. This can be done using
1227 // llvm.insertelement.
1228 // * Insert the source nested aggregate into the destination nested
1229 // aggregate.
1230
1231 // Determine if we need to extract/insert a 1D vector out of the aggregate.
1232 bool isNestedAggregate = isa<LLVM::LLVMArrayType>(llvmResultType);
1233 // Determine if we need to insert a scalar into the 1D vector.
1234 bool insertIntoInnermostDim =
1235 static_cast<int64_t>(positionVec.size()) == destVectorType.getRank();
1236
1237 ArrayRef<OpFoldResult> positionOf1DVectorWithinAggregate(
1238 positionVec.begin(),
1239 insertIntoInnermostDim ? positionVec.size() - 1 : positionVec.size());
1240 OpFoldResult positionOfScalarWithin1DVector;
1241 if (destVectorType.getRank() == 0) {
1242 // Since the LLVM type converter converts 0D vectors to 1D vectors, we
1243 // need to create a 0 here as the position into the 1D vector.
1244 Type idxType = typeConverter->convertType(rewriter.getIndexType());
1245 positionOfScalarWithin1DVector = rewriter.getZeroAttr(idxType);
1246 } else if (insertIntoInnermostDim) {
1247 positionOfScalarWithin1DVector = positionVec.back();
1248 }
1249
1250 // We are going to mutate this 1D vector until it is either the final
1251 // result (in the non-aggregate case) or the value that needs to be
1252 // inserted into the aggregate result.
1253 Value sourceAggregate = adaptor.getValueToStore();
1254 if (insertIntoInnermostDim) {
1255 // Scalar-into-1D-vector case, so we know we will have to create a
1256 // InsertElementOp. The question is into what destination.
1257 if (isNestedAggregate) {
1258 // Aggregate case: the destination for the InsertElementOp needs to be
1259 // extracted from the aggregate.
1260 if (!llvm::all_of(positionOf1DVectorWithinAggregate,
1261 llvm::IsaPred<Attribute>)) {
1262 // llvm.extractvalue does not support dynamic dimensions.
1263 return failure();
1264 }
1265 sourceAggregate = LLVM::ExtractValueOp::create(
1266 rewriter, loc, adaptor.getDest(),
1267 getAsIntegers(positionOf1DVectorWithinAggregate));
1268 } else {
1269 // No-aggregate case. The destination for the InsertElementOp is just
1270 // the insertOp's destination.
1271 sourceAggregate = adaptor.getDest();
1272 }
1273 // Insert the scalar into the 1D vector.
1274 sourceAggregate = LLVM::InsertElementOp::create(
1275 rewriter, loc, sourceAggregate.getType(), sourceAggregate,
1276 adaptor.getValueToStore(),
1277 getAsLLVMValue(rewriter, loc, positionOfScalarWithin1DVector));
1278 }
1279
1280 Value result = sourceAggregate;
1281 if (isNestedAggregate) {
1282 result = LLVM::InsertValueOp::create(
1283 rewriter, loc, adaptor.getDest(), sourceAggregate,
1284 getAsIntegers(positionOf1DVectorWithinAggregate));
1285 }
1286
1287 rewriter.replaceOp(insertOp, result);
1288 return success();
1289 }
1290};
1291
1292/// Lower vector.scalable.insert ops to LLVM vector.insert
1293struct VectorScalableInsertOpLowering
1294 : public ConvertOpToLLVMPattern<vector::ScalableInsertOp> {
1295 using ConvertOpToLLVMPattern<
1296 vector::ScalableInsertOp>::ConvertOpToLLVMPattern;
1297
1298 LogicalResult
1299 matchAndRewrite(vector::ScalableInsertOp insOp, OpAdaptor adaptor,
1300 ConversionPatternRewriter &rewriter) const override {
1301 rewriter.replaceOpWithNewOp<LLVM::vector_insert>(
1302 insOp, adaptor.getDest(), adaptor.getValueToStore(), adaptor.getPos());
1303 return success();
1304 }
1305};
1306
1307/// Lower vector.scalable.extract ops to LLVM vector.extract
1308struct VectorScalableExtractOpLowering
1309 : public ConvertOpToLLVMPattern<vector::ScalableExtractOp> {
1310 using ConvertOpToLLVMPattern<
1311 vector::ScalableExtractOp>::ConvertOpToLLVMPattern;
1312
1313 LogicalResult
1314 matchAndRewrite(vector::ScalableExtractOp extOp, OpAdaptor adaptor,
1315 ConversionPatternRewriter &rewriter) const override {
1316 rewriter.replaceOpWithNewOp<LLVM::vector_extract>(
1317 extOp, typeConverter->convertType(extOp.getResultVectorType()),
1318 adaptor.getSource(), adaptor.getPos());
1319 return success();
1320 }
1321};
1322
1323/// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
1324///
1325/// Example:
1326/// ```
1327/// %d = vector.fma %a, %b, %c : vector<2x4xf32>
1328/// ```
1329/// is rewritten into:
1330/// ```
1331/// %r = vector.broadcast %f0 : f32 to vector<2x4xf32>
1332/// %va = vector.extractvalue %a[0] : vector<2x4xf32>
1333/// %vb = vector.extractvalue %b[0] : vector<2x4xf32>
1334/// %vc = vector.extractvalue %c[0] : vector<2x4xf32>
1335/// %vd = vector.fma %va, %vb, %vc : vector<4xf32>
1336/// %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
1337/// %va2 = vector.extractvalue %a2[1] : vector<2x4xf32>
1338/// %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32>
1339/// %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32>
1340/// %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
1341/// %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
1342/// // %r3 holds the final value.
1343/// ```
1344class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
1345public:
1346 using Base::Base;
1347
1348 void initialize() {
1349 // This pattern recursively unpacks one dimension at a time. The recursion
1350 // is bounded because the rank is strictly decreasing.
1351 setHasBoundedRewriteRecursion();
1352 }
1353
1354 LogicalResult matchAndRewrite(FMAOp op,
1355 PatternRewriter &rewriter) const override {
1356 auto vType = op.getVectorType();
1357 if (vType.getRank() < 2)
1358 return failure();
1359
1360 auto loc = op.getLoc();
1361 auto elemType = vType.getElementType();
1362 Value zero = arith::ConstantOp::create(rewriter, loc, elemType,
1363 rewriter.getZeroAttr(elemType));
1364 Value desc = vector::BroadcastOp::create(rewriter, loc, vType, zero);
1365 for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
1366 Value extrLHS = ExtractOp::create(rewriter, loc, op.getLhs(), i);
1367 Value extrRHS = ExtractOp::create(rewriter, loc, op.getRhs(), i);
1368 Value extrACC = ExtractOp::create(rewriter, loc, op.getAcc(), i);
1369 Value fma = FMAOp::create(rewriter, loc, extrLHS, extrRHS, extrACC);
1370 desc = InsertOp::create(rewriter, loc, fma, desc, i);
1371 }
1372 rewriter.replaceOp(op, desc);
1373 return success();
1374 }
1375};
1376
1377/// Returns the strides if the memory underlying `memRefType` has a contiguous
1378/// static layout.
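/// For example (illustrative), an identity-layout `memref<2x4xf32>` has
/// strides [4, 1]; the check `strides[i] == strides[i+1] * sizes[i+1]`
/// (4 == 1 * 4) holds, so the buffer is contiguous.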
1379static std::optional<SmallVector<int64_t, 4>>
1380computeContiguousStrides(MemRefType memRefType) {
1381 int64_t offset;
1382 SmallVector<int64_t, 4> strides;
1383 if (failed(memRefType.getStridesAndOffset(strides, offset)))
1384 return std::nullopt;
1385 if (!strides.empty() && strides.back() != 1)
1386 return std::nullopt;
1387 // If no layout or identity layout, this is contiguous by definition.
1388 if (memRefType.getLayout().isIdentity())
1389 return strides;
1390
1391 // Otherwise, we must determine contiguity from the shapes. This can only ever
1392 // work in static cases because MemRefType is underspecified to represent
1393 // contiguous dynamic shapes in other ways than with just empty/identity
1394 // layout.
1395 auto sizes = memRefType.getShape();
1396 for (int index = 0, e = strides.size() - 1; index < e; ++index) {
1397 if (ShapedType::isDynamic(sizes[index + 1]) ||
1398 ShapedType::isDynamic(strides[index]) ||
1399 ShapedType::isDynamic(strides[index + 1]))
1400 return std::nullopt;
1401 if (strides[index] != strides[index + 1] * sizes[index + 1])
1402 return std::nullopt;
1403 }
1404 return strides;
1405}
1406
1407class VectorTypeCastOpConversion
1408 : public ConvertOpToLLVMPattern<vector::TypeCastOp> {
1409public:
1410 using ConvertOpToLLVMPattern<vector::TypeCastOp>::ConvertOpToLLVMPattern;
1411
1412 LogicalResult
1413 matchAndRewrite(vector::TypeCastOp castOp, OpAdaptor adaptor,
1414 ConversionPatternRewriter &rewriter) const override {
1415 auto loc = castOp->getLoc();
1416 MemRefType sourceMemRefType =
1417 cast<MemRefType>(castOp.getOperand().getType());
1418 MemRefType targetMemRefType = castOp.getType();
1419
1420 // Only static shape casts supported atm.
1421 if (!sourceMemRefType.hasStaticShape() ||
1422 !targetMemRefType.hasStaticShape())
1423 return failure();
1424
1425 auto llvmSourceDescriptorTy =
1426 dyn_cast<LLVM::LLVMStructType>(adaptor.getOperands()[0].getType());
1427 if (!llvmSourceDescriptorTy)
1428 return failure();
1429 MemRefDescriptor sourceMemRef(adaptor.getOperands()[0]);
1430
1431 auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
1432 typeConverter->convertType(targetMemRefType));
1433 if (!llvmTargetDescriptorTy)
1434 return failure();
1435
1436 // Only contiguous source buffers supported atm.
1437 auto sourceStrides = computeContiguousStrides(sourceMemRefType);
1438 if (!sourceStrides)
1439 return failure();
1440 auto targetStrides = computeContiguousStrides(targetMemRefType);
1441 if (!targetStrides)
1442 return failure();
1443 // Only support static strides for now, regardless of contiguity.
1444 if (llvm::any_of(*targetStrides, ShapedType::isDynamic))
1445 return failure();
1446
1447 auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
1448
1449 // Create descriptor.
1450 auto desc = MemRefDescriptor::poison(rewriter, loc, llvmTargetDescriptorTy);
1451 // Set allocated ptr.
1452 Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
1453 desc.setAllocatedPtr(rewriter, loc, allocated);
1454
1455 // Set aligned ptr.
1456 Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
1457 desc.setAlignedPtr(rewriter, loc, ptr);
1458 // Fill offset 0.
1459 auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
1460 auto zero = LLVM::ConstantOp::create(rewriter, loc, int64Ty, attr);
1461 desc.setOffset(rewriter, loc, zero);
1462
1463 // Fill size and stride descriptors in memref.
1464 for (const auto &indexedSize :
1465 llvm::enumerate(targetMemRefType.getShape())) {
1466 int64_t index = indexedSize.index();
1467 auto sizeAttr =
1468 rewriter.getIntegerAttr(rewriter.getIndexType(), indexedSize.value());
1469 auto size = LLVM::ConstantOp::create(rewriter, loc, int64Ty, sizeAttr);
1470 desc.setSize(rewriter, loc, index, size);
1471 auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
1472 (*targetStrides)[index]);
1473 auto stride =
1474 LLVM::ConstantOp::create(rewriter, loc, int64Ty, strideAttr);
1475 desc.setStride(rewriter, loc, index, stride);
1476 }
1477
1478 rewriter.replaceOp(castOp, {desc});
1479 return success();
1480 }
1481};
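// Illustrative sketch (not from the upstream file) of what this pattern
// produces, assuming a statically shaped, contiguous source:
//
//   %vm = vector.type_cast %m : memref<8x8xf32> to memref<vector<8x8xf32>>
//
// becomes a fresh LLVM memref descriptor that reuses the source's allocated
// and aligned pointers, sets the offset to 0, and fills in the static sizes
// and strides of the target memref type.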
1482
1483/// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only).
1484/// Non-scalable versions of this operation are handled in Vector Transforms.
1485class VectorCreateMaskOpConversion
1486 : public OpConversionPattern<vector::CreateMaskOp> {
1487public:
1488 explicit VectorCreateMaskOpConversion(MLIRContext *context,
1489 bool enableIndexOpt)
1490 : OpConversionPattern<vector::CreateMaskOp>(context),
1491 force32BitVectorIndices(enableIndexOpt) {}
1492
1493 LogicalResult
1494 matchAndRewrite(vector::CreateMaskOp op, OpAdaptor adaptor,
1495 ConversionPatternRewriter &rewriter) const override {
1496 auto dstType = op.getType();
1497 if (dstType.getRank() != 1 || !cast<VectorType>(dstType).isScalable())
1498 return failure();
1499 IntegerType idxType =
1500 force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
1501 auto loc = op->getLoc();
1502 Value indices = LLVM::StepVectorOp::create(
1503 rewriter, loc,
1504 LLVM::getVectorType(idxType, dstType.getShape()[0],
1505 /*isScalable=*/true));
1506 auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType,
1507 adaptor.getOperands()[0]);
1508 Value bounds = BroadcastOp::create(rewriter, loc, indices.getType(), bound);
1509 Value comp = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
1510 indices, bounds);
1511 rewriter.replaceOp(op, comp);
1512 return success();
1513 }
1514
1515private:
1516 const bool force32BitVectorIndices;
1517};
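// Illustrative sketch (not from the upstream file), assuming
// force32BitVectorIndices and a scalable 1-D result:
//
//   %m = vector.create_mask %n : vector<[4]xi1>
//
// becomes, roughly:
//
//   %step  = llvm.intr.stepvector : vector<[4]xi32>
//   %bound = vector.broadcast %n_i32 : i32 to vector<[4]xi32>
//   %m     = arith.cmpi slt, %step, %bound : vector<[4]xi32>
//
// where %n_i32 stands for %n cast to i32 (a hypothetical name used only for
// this illustration).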
1518
1519class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
1520 SymbolTableCollection *symbolTables = nullptr;
1521
1522public:
1523 explicit VectorPrintOpConversion(
1524 const LLVMTypeConverter &typeConverter,
1525 SymbolTableCollection *symbolTables = nullptr)
1526 : ConvertOpToLLVMPattern<vector::PrintOp>(typeConverter),
1527 symbolTables(symbolTables) {}
1528
1529 // Lowering implementation that relies on a small runtime support library,
1530 // which only needs to provide a few printing methods (single value for all
1531 // data types, opening/closing bracket, comma, newline). The lowering splits
1532 // the vector into elementary printing operations. The advantage of this
1533 // approach is that the library can remain unaware of all low-level
1534 // implementation details of vectors while still supporting output of any
1535 // shaped and dimensioned vector.
1536 //
1537 // Note: This lowering only handles scalars; n-D vectors are broken down
1538 // into loops of scalar prints in VectorToSCF.
1539 //
1540 // TODO: rely solely on libc in future? something else?
1541 //
1542 LogicalResult
1543 matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor,
1544 ConversionPatternRewriter &rewriter) const override {
1545 auto parent = printOp->getParentOfType<ModuleOp>();
1546 if (!parent)
1547 return failure();
1548
1549 auto loc = printOp->getLoc();
1550
1551 if (auto value = adaptor.getSource()) {
1552 Type printType = printOp.getPrintType();
1553 if (isa<VectorType>(printType)) {
1554 // Vectors should be broken into elementary print ops in VectorToSCF.
1555 return failure();
1556 }
1557 if (failed(emitScalarPrint(rewriter, parent, loc, printType, value)))
1558 return failure();
1559 }
1560
1561 auto punct = printOp.getPunctuation();
1562 if (auto stringLiteral = printOp.getStringLiteral()) {
1563 auto createResult =
1564 LLVM::createPrintStrCall(rewriter, loc, parent, "vector_print_str",
1565 *stringLiteral, *getTypeConverter(),
1566 /*addNewline=*/false);
1567 if (createResult.failed())
1568 return failure();
1569
1570 } else if (punct != PrintPunctuation::NoPunctuation) {
1571 FailureOr<LLVM::LLVMFuncOp> op = [&]() {
1572 switch (punct) {
1573 case PrintPunctuation::Close:
1574 return LLVM::lookupOrCreatePrintCloseFn(rewriter, parent,
1575 symbolTables);
1576 case PrintPunctuation::Open:
1577 return LLVM::lookupOrCreatePrintOpenFn(rewriter, parent,
1578 symbolTables);
1579 case PrintPunctuation::Comma:
1580 return LLVM::lookupOrCreatePrintCommaFn(rewriter, parent,
1581 symbolTables);
1582 case PrintPunctuation::NewLine:
1583 return LLVM::lookupOrCreatePrintNewlineFn(rewriter, parent,
1584 symbolTables);
1585 default:
1586 llvm_unreachable("unexpected punctuation");
1587 }
1588 }();
1589 if (failed(op))
1590 return failure();
1591 emitCall(rewriter, printOp->getLoc(), op.value());
1592 }
1593
1594 rewriter.eraseOp(printOp);
1595 return success();
1596 }
1597
1598private:
1599 enum class PrintConversion {
1600 // clang-format off
1601 None,
1602 ZeroExt64,
1603 SignExt64,
1604 Bitcast16
1605 // clang-format on
1606 };
1607
1608 LogicalResult emitScalarPrint(ConversionPatternRewriter &rewriter,
1609 ModuleOp parent, Location loc, Type printType,
1610 Value value) const {
1611 if (typeConverter->convertType(printType) == nullptr)
1612 return failure();
1613
1614 // Make sure element type has runtime support.
1615 PrintConversion conversion = PrintConversion::None;
1616 FailureOr<Operation *> printer;
1617 if (printType.isF32()) {
1618 printer = LLVM::lookupOrCreatePrintF32Fn(rewriter, parent, symbolTables);
1619 } else if (printType.isF64()) {
1620 printer = LLVM::lookupOrCreatePrintF64Fn(rewriter, parent, symbolTables);
1621 } else if (printType.isF16()) {
1622 conversion = PrintConversion::Bitcast16; // bits!
1623 printer = LLVM::lookupOrCreatePrintF16Fn(rewriter, parent, symbolTables);
1624 } else if (printType.isBF16()) {
1625 conversion = PrintConversion::Bitcast16; // bits!
1626 printer = LLVM::lookupOrCreatePrintBF16Fn(rewriter, parent, symbolTables);
1627 } else if (printType.isIndex()) {
1628 printer = LLVM::lookupOrCreatePrintU64Fn(rewriter, parent, symbolTables);
1629 } else if (auto intTy = dyn_cast<IntegerType>(printType)) {
1630 // Integers need a zero or sign extension on the operand
1631 // (depending on the source type) as well as a signed or
1632 // unsigned print method. Up to 64-bit is supported.
1633 unsigned width = intTy.getWidth();
1634 if (intTy.isUnsigned()) {
1635 if (width <= 64) {
1636 if (width < 64)
1637 conversion = PrintConversion::ZeroExt64;
1638 printer =
1639 LLVM::lookupOrCreatePrintU64Fn(rewriter, parent, symbolTables);
1640 } else {
1641 return failure();
1642 }
1643 } else {
1644 assert(intTy.isSignless() || intTy.isSigned());
1645 if (width <= 64) {
1646 // Note that we *always* zero extend booleans (1-bit integers),
1647 // so that true/false is printed as 1/0 rather than -1/0.
1648 if (width == 1)
1649 conversion = PrintConversion::ZeroExt64;
1650 else if (width < 64)
1651 conversion = PrintConversion::SignExt64;
1652 printer =
1653 LLVM::lookupOrCreatePrintI64Fn(rewriter, parent, symbolTables);
1654 } else {
1655 return failure();
1656 }
1657 }
1658 } else if (auto floatTy = dyn_cast<FloatType>(printType)) {
1659 // Print other floating-point types using the APFloat runtime library.
1660 int32_t sem =
1661 llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
1662 Value semValue = LLVM::ConstantOp::create(
1663 rewriter, loc, rewriter.getI32Type(),
1664 rewriter.getIntegerAttr(rewriter.getI32Type(), sem));
1665 Value floatBits =
1666 LLVM::ZExtOp::create(rewriter, loc, rewriter.getI64Type(), value);
1667 printer =
1668 LLVM::lookupOrCreateApFloatPrintFn(rewriter, parent, symbolTables);
1669 emitCall(rewriter, loc, printer.value(),
1670 ValueRange({semValue, floatBits}));
1671 return success();
1672 } else {
1673 return failure();
1674 }
1675 if (failed(printer))
1676 return failure();
1677
1678 switch (conversion) {
1679 case PrintConversion::ZeroExt64:
1680 value = arith::ExtUIOp::create(
1681 rewriter, loc, IntegerType::get(rewriter.getContext(), 64), value);
1682 break;
1683 case PrintConversion::SignExt64:
1684 value = arith::ExtSIOp::create(
1685 rewriter, loc, IntegerType::get(rewriter.getContext(), 64), value);
1686 break;
1687 case PrintConversion::Bitcast16:
1688 value = LLVM::BitcastOp::create(
1689 rewriter, loc, IntegerType::get(rewriter.getContext(), 16), value);
1690 break;
1691 case PrintConversion::None:
1692 break;
1693 }
1694 emitCall(rewriter, loc, printer.value(), value);
1695 return success();
1696 }
1697
1698 // Helper to emit a call.
1699 static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
1700 Operation *ref, ValueRange params = ValueRange()) {
1701 LLVM::CallOp::create(rewriter, loc, TypeRange(), SymbolRefAttr::get(ref),
1702 params);
1703 }
1704};
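// Illustrative sketch (not from the upstream file): with the default newline
// punctuation,
//
//   vector.print %x : f32
//
// lowers to two calls into the small runtime print library, roughly:
//
//   llvm.call @printF32(%x) : (f32) -> ()
//   llvm.call @printNewline() : () -> ()
//
// The callee names follow the lookupOrCreatePrint*Fn helpers used above.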
1705
1706/// A broadcast of a scalar is lowered to an insertelement + a shufflevector
1707/// operation. Only broadcasts to 0-d and 1-d vectors are lowered by this
1708/// pattern; the higher-rank cases are handled by another pattern.
1709struct VectorBroadcastScalarToLowRankLowering
1710 : public ConvertOpToLLVMPattern<vector::BroadcastOp> {
1711 using ConvertOpToLLVMPattern<vector::BroadcastOp>::ConvertOpToLLVMPattern;
1712
1713 LogicalResult
1714 matchAndRewrite(vector::BroadcastOp broadcast, OpAdaptor adaptor,
1715 ConversionPatternRewriter &rewriter) const override {
1716 if (isa<VectorType>(broadcast.getSourceType()))
1717 return rewriter.notifyMatchFailure(
1718 broadcast, "broadcast from vector type not handled");
1719
1720 VectorType resultType = broadcast.getType();
1721 if (resultType.getRank() > 1)
1722 return rewriter.notifyMatchFailure(broadcast,
1723 "broadcast to 2+-d handled elsewhere");
1724
1725 // First insert it into a poison vector so we can shuffle it.
1726 auto vectorType = typeConverter->convertType(broadcast.getType());
1727 Value poison =
1728 LLVM::PoisonOp::create(rewriter, broadcast.getLoc(), vectorType);
1729 auto zero = LLVM::ConstantOp::create(
1730 rewriter, broadcast.getLoc(),
1731 typeConverter->convertType(rewriter.getIntegerType(32)),
1732 rewriter.getZeroAttr(rewriter.getIntegerType(32)));
1733
1734 // For 0-d vector, we simply do `insertelement`.
1735 if (resultType.getRank() == 0) {
1736 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
1737 broadcast, vectorType, poison, adaptor.getSource(), zero);
1738 return success();
1739 }
1740
1741 auto v =
1742 LLVM::InsertElementOp::create(rewriter, broadcast.getLoc(), vectorType,
1743 poison, adaptor.getSource(), zero);
1744
1745 // For 1-d vector, we additionally do a `shufflevector`.
1746 int64_t width = cast<VectorType>(broadcast.getType()).getDimSize(0);
1747 SmallVector<int32_t> zeroValues(width, 0);
1748
1749 // Shuffle the value across the desired number of elements.
1750 auto shuffle = rewriter.createOrFold<LLVM::ShuffleVectorOp>(
1751 broadcast.getLoc(), v, poison, zeroValues);
1752 rewriter.replaceOp(broadcast, shuffle);
1753 return success();
1754 }
1755};
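// Illustrative sketch (not from the upstream file) for the 1-d case:
//
//   %r = vector.broadcast %s : f32 to vector<4xf32>
//
// becomes, roughly:
//
//   %p = llvm.mlir.poison : vector<4xf32>
//   %v = llvm.insertelement %s, %p[%c0 : i32] : vector<4xf32>
//   %r = llvm.shufflevector %v, %p [0, 0, 0, 0] : vector<4xf32>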
1756
1757/// The broadcast of a scalar is lowered to an insertelement + a shufflevector
1758/// operation. Only broadcasts to 2+-d vector result types are lowered by this
1759/// pattern; the 0-d and 1-d cases are handled by another pattern. Broadcasts
1760/// from vectors are not converted to LLVM; only broadcasts from scalars are.
1761struct VectorBroadcastScalarToNdLowering
1762 : public ConvertOpToLLVMPattern<BroadcastOp> {
1763 using ConvertOpToLLVMPattern<BroadcastOp>::ConvertOpToLLVMPattern;
1764
1765 LogicalResult
1766 matchAndRewrite(BroadcastOp broadcast, OpAdaptor adaptor,
1767 ConversionPatternRewriter &rewriter) const override {
1768 if (isa<VectorType>(broadcast.getSourceType()))
1769 return rewriter.notifyMatchFailure(
1770 broadcast, "broadcast from vector type not handled");
1771
1772 VectorType resultType = broadcast.getType();
1773 if (resultType.getRank() <= 1)
1774 return rewriter.notifyMatchFailure(
1775 broadcast, "broadcast to 1-d or 0-d handled elsewhere");
1776
1777 // First insert it into a poison vector so we can shuffle it.
1778 auto loc = broadcast.getLoc();
1779 auto vectorTypeInfo =
1780 LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter());
1781 auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy;
1782 auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy;
1783 if (!llvmNDVectorTy || !llvm1DVectorTy)
1784 return failure();
1785
1786 // Construct returned value.
1787 Value desc = LLVM::PoisonOp::create(rewriter, loc, llvmNDVectorTy);
1788
1789 // Construct a 1-D vector with the broadcasted value that we insert in all
1790 // the places within the returned descriptor.
1791 Value vdesc = LLVM::PoisonOp::create(rewriter, loc, llvm1DVectorTy);
1792 auto zero = LLVM::ConstantOp::create(
1793 rewriter, loc, typeConverter->convertType(rewriter.getIntegerType(32)),
1794 rewriter.getZeroAttr(rewriter.getIntegerType(32)));
1795 Value v = LLVM::InsertElementOp::create(rewriter, loc, llvm1DVectorTy,
1796 vdesc, adaptor.getSource(), zero);
1797
1798 // Shuffle the value across the desired number of elements.
1799 int64_t width = resultType.getDimSize(resultType.getRank() - 1);
1800 SmallVector<int32_t> zeroValues(width, 0);
1801 v = LLVM::ShuffleVectorOp::create(rewriter, loc, v, v, zeroValues);
1802
1803 // Iterate over the linear index, convert it to the coordinate space, and
1804 // insert the broadcasted 1-D vector at each position.
1805 nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
1806 desc = LLVM::InsertValueOp::create(rewriter, loc, desc, v, position);
1807 });
1808 rewriter.replaceOp(broadcast, desc);
1809 return success();
1810 }
1811};
1812
1813/// Conversion pattern for a `vector.interleave`.
1814/// This supports fixed-sized vectors and scalable vectors.
1815struct VectorInterleaveOpLowering
1816 : public ConvertOpToLLVMPattern<vector::InterleaveOp> {
1817 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
1818
1819 LogicalResult
1820 matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
1821 ConversionPatternRewriter &rewriter) const override {
1822 VectorType resultType = interleaveOp.getResultVectorType();
1823 // n-D interleaves should have been lowered already.
1824 if (resultType.getRank() != 1)
1825 return rewriter.notifyMatchFailure(interleaveOp,
1826 "InterleaveOp not rank 1");
1827 // If the result is rank 1, then this directly maps to LLVM.
1828 if (resultType.isScalable()) {
1829 rewriter.replaceOpWithNewOp<LLVM::vector_interleave2>(
1830 interleaveOp, typeConverter->convertType(resultType),
1831 adaptor.getLhs(), adaptor.getRhs());
1832 return success();
1833 }
1834 // Lower fixed-size interleaves to a shufflevector. While the
1835 // vector.interleave2 intrinsic supports fixed and scalable vectors, the
1836 // LangRef still recommends that fixed vectors use shufflevector, see:
1837 // https://llvm.org/docs/LangRef.html#id876.
1838 int64_t resultVectorSize = resultType.getNumElements();
1839 SmallVector<int32_t> interleaveShuffleMask;
1840 interleaveShuffleMask.reserve(resultVectorSize);
1841 for (int i = 0, end = resultVectorSize / 2; i < end; ++i) {
1842 interleaveShuffleMask.push_back(i);
1843 interleaveShuffleMask.push_back((resultVectorSize / 2) + i);
1844 }
1845 rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(
1846 interleaveOp, adaptor.getLhs(), adaptor.getRhs(),
1847 interleaveShuffleMask);
1848 return success();
1849 }
1850};
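// Illustrative sketch (not from the upstream file) for the fixed-size case:
//
//   %r = vector.interleave %a, %b : vector<4xf32> -> vector<8xf32>
//
// becomes a single shuffle with an alternating mask, roughly:
//
//   %r = llvm.shufflevector %a, %b [0, 4, 1, 5, 2, 6, 3, 7] : vector<4xf32>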
1851
1852/// Conversion pattern for a `vector.deinterleave`.
1853/// This supports fixed-sized vectors and scalable vectors.
1854struct VectorDeinterleaveOpLowering
1855 : public ConvertOpToLLVMPattern<vector::DeinterleaveOp> {
1856 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
1857
1858 LogicalResult
1859 matchAndRewrite(vector::DeinterleaveOp deinterleaveOp, OpAdaptor adaptor,
1860 ConversionPatternRewriter &rewriter) const override {
1861 VectorType resultType = deinterleaveOp.getResultVectorType();
1862 VectorType sourceType = deinterleaveOp.getSourceVectorType();
1863 auto loc = deinterleaveOp.getLoc();
1864
1865 // Note: n-D deinterleave operations should be lowered to the 1-D form
1866 // before converting to LLVM.
1867 if (resultType.getRank() != 1)
1868 return rewriter.notifyMatchFailure(deinterleaveOp,
1869 "DeinterleaveOp not rank 1");
1870
1871 if (resultType.isScalable()) {
1872 const auto *llvmTypeConverter = this->getTypeConverter();
1873 auto deinterleaveResults = deinterleaveOp.getResultTypes();
1874 auto packedOpResults =
1875 llvmTypeConverter->packOperationResults(deinterleaveResults);
1876 auto intrinsic = LLVM::vector_deinterleave2::create(
1877 rewriter, loc, packedOpResults, adaptor.getSource());
1878
1879 auto evenResult = LLVM::ExtractValueOp::create(
1880 rewriter, loc, intrinsic->getResult(0), 0);
1881 auto oddResult = LLVM::ExtractValueOp::create(rewriter, loc,
1882 intrinsic->getResult(0), 1);
1883
1884 rewriter.replaceOp(deinterleaveOp, ValueRange{evenResult, oddResult});
1885 return success();
1886 }
1887 // Lower fixed-size deinterleave to two shufflevectors. While the
1888 // vector.deinterleave2 intrinsic supports fixed and scalable vectors, the
1889 // LangRef still recommends that fixed vectors use shufflevector, see:
1890 // https://llvm.org/docs/LangRef.html#id889.
1891 int64_t resultVectorSize = resultType.getNumElements();
1892 SmallVector<int32_t> evenShuffleMask;
1893 SmallVector<int32_t> oddShuffleMask;
1894
1895 evenShuffleMask.reserve(resultVectorSize);
1896 oddShuffleMask.reserve(resultVectorSize);
1897
1898 for (int i = 0; i < sourceType.getNumElements(); ++i) {
1899 if (i % 2 == 0)
1900 evenShuffleMask.push_back(i);
1901 else
1902 oddShuffleMask.push_back(i);
1903 }
1904
1905 auto poison = LLVM::PoisonOp::create(rewriter, loc, sourceType);
1906 auto evenShuffle = LLVM::ShuffleVectorOp::create(
1907 rewriter, loc, adaptor.getSource(), poison, evenShuffleMask);
1908 auto oddShuffle = LLVM::ShuffleVectorOp::create(
1909 rewriter, loc, adaptor.getSource(), poison, oddShuffleMask);
1910
1911 rewriter.replaceOp(deinterleaveOp, ValueRange{evenShuffle, oddShuffle});
1912 return success();
1913 }
1914};
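// Illustrative sketch (not from the upstream file) for the fixed-size case:
//
//   %even, %odd = vector.deinterleave %src : vector<8xf32> -> vector<4xf32>
//
// becomes two shuffles against a poison vector, roughly:
//
//   %p    = llvm.mlir.poison : vector<8xf32>
//   %even = llvm.shufflevector %src, %p [0, 2, 4, 6] : vector<8xf32>
//   %odd  = llvm.shufflevector %src, %p [1, 3, 5, 7] : vector<8xf32>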
1915
1916/// Conversion pattern for a `vector.from_elements`.
1917struct VectorFromElementsLowering
1918 : public ConvertOpToLLVMPattern<vector::FromElementsOp> {
1919 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
1920
1921 LogicalResult
1922 matchAndRewrite(vector::FromElementsOp fromElementsOp, OpAdaptor adaptor,
1923 ConversionPatternRewriter &rewriter) const override {
1924 Location loc = fromElementsOp.getLoc();
1925 VectorType vectorType = fromElementsOp.getType();
1926 // Only support 1-D vectors. Multi-dimensional vectors should have been
1927 // transformed to 1-D vectors by the vector-to-vector transformations before
1928 // this.
1929 if (vectorType.getRank() > 1)
1930 return rewriter.notifyMatchFailure(fromElementsOp,
1931 "rank > 1 vectors are not supported");
1932 Type llvmType = typeConverter->convertType(vectorType);
1933 Type llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
1934 Value result = LLVM::PoisonOp::create(rewriter, loc, llvmType);
1935 for (auto [idx, val] : llvm::enumerate(adaptor.getElements())) {
1936 auto constIdx =
1937 LLVM::ConstantOp::create(rewriter, loc, llvmIndexType, idx);
1938 result = LLVM::InsertElementOp::create(rewriter, loc, llvmType, result,
1939 val, constIdx);
1940 }
1941 rewriter.replaceOp(fromElementsOp, result);
1942 return success();
1943 }
1944};
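// Illustrative sketch (not from the upstream file):
//
//   %v = vector.from_elements %a, %b : vector<2xf32>
//
// becomes a chain of insertions into a poison vector, roughly:
//
//   %p  = llvm.mlir.poison : vector<2xf32>
//   %v0 = llvm.insertelement %a, %p[%c0 : i64] : vector<2xf32>
//   %v  = llvm.insertelement %b, %v0[%c1 : i64] : vector<2xf32>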
1945
1946/// Conversion pattern for a `vector.to_elements`.
1947struct VectorToElementsLowering
1948 : public ConvertOpToLLVMPattern<vector::ToElementsOp> {
1949 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
1950
1951 LogicalResult
1952 matchAndRewrite(vector::ToElementsOp toElementsOp, OpAdaptor adaptor,
1953 ConversionPatternRewriter &rewriter) const override {
1954 Location loc = toElementsOp.getLoc();
1955 auto idxType = typeConverter->convertType(rewriter.getIndexType());
1956 Value source = adaptor.getSource();
1957
1958 SmallVector<Value> results(toElementsOp->getNumResults());
1959 for (auto [idx, element] : llvm::enumerate(toElementsOp.getElements())) {
1960 // Create an extractelement operation only for results that are not dead.
1961 if (element.use_empty())
1962 continue;
1963
1964 auto constIdx = LLVM::ConstantOp::create(
1965 rewriter, loc, idxType, rewriter.getIntegerAttr(idxType, idx));
1966 auto llvmType = typeConverter->convertType(element.getType());
1967
1968 Value result = LLVM::ExtractElementOp::create(rewriter, loc, llvmType,
1969 source, constIdx);
1970 results[idx] = result;
1971 }
1972
1973 rewriter.replaceOp(toElementsOp, results);
1974 return success();
1975 }
1976};
1977
1978/// Conversion pattern for vector.step.
1979struct VectorScalableStepOpLowering
1980 : public ConvertOpToLLVMPattern<vector::StepOp> {
1981 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
1982
1983 LogicalResult
1984 matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor,
1985 ConversionPatternRewriter &rewriter) const override {
1986 auto resultType = cast<VectorType>(stepOp.getType());
1987 if (!resultType.isScalable()) {
1988 return failure();
1989 }
1990 Type llvmType = typeConverter->convertType(stepOp.getType());
1991 rewriter.replaceOpWithNewOp<LLVM::StepVectorOp>(stepOp, llvmType);
1992 return success();
1993 }
1994};
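// Illustrative sketch (not from the upstream file): a scalable
//
//   %s = vector.step : vector<[4]xindex>
//
// maps directly onto the LLVM step-vector intrinsic, roughly:
//
//   %s = llvm.intr.stepvector : vector<[4]xi64>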
1995
1996/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
1997/// semantics to:
1998/// ```
1999/// %flattened_a = vector.shape_cast %a
2000/// %flattened_b = vector.shape_cast %b
2001/// %flattened_d = vector.matrix_multiply %flattened_a, %flattened_b
2002/// %d = vector.shape_cast %flattened_d
2003/// %e = add %c, %d
2004/// ```
2005/// `vector.matrix_multiply` later lowers to `llvm.matrix.multiply`.
2006class ContractionOpToMatmulOpLowering
2007 : public vector::MaskableOpRewritePattern<vector::ContractionOp> {
2008public:
2009 using MaskableOpRewritePattern::MaskableOpRewritePattern;
2010
2011 ContractionOpToMatmulOpLowering(MLIRContext *context,
2012 PatternBenefit benefit = 100)
2013 : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit) {}
2014
2015 FailureOr<Value>
2016 matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
2017 PatternRewriter &rewriter) const override;
2018};
2019
2020/// Lower a qualifying `vector.contract %a, %b, %c` (with row-major matmul
2021/// semantics) directly into `llvm.intr.matrix.multiply`:
2022/// BEFORE:
2023/// ```mlir
2024/// %res = vector.contract #matmat_trait %lhs, %rhs, %acc
2025/// : vector<2x4xf32>, vector<4x3xf32> into vector<2x3xf32>
2026/// ```
2027///
2028/// AFTER:
2029/// ```mlir
2030/// %lhs = vector.shape_cast %arg0 : vector<2x4xf32> to vector<8xf32>
2031/// %rhs = vector.shape_cast %arg1 : vector<4x3xf32> to vector<12xf32>
2032/// %matmul = llvm.intr.matrix.multiply %lhs, %rhs
2033/// %res = arith.addf %acc, %matmul : vector<2x3xf32>
2034/// ```
2035///
2036/// Scalable vectors are not supported.
2037FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
2038 vector::ContractionOp op, MaskingOpInterface maskOp,
2039 PatternRewriter &rew) const {
2040 // TODO: Support vector.mask.
2041 if (maskOp)
2042 return failure();
2043
2044 auto iteratorTypes = op.getIteratorTypes().getValue();
2045 if (!isParallelIterator(iteratorTypes[0]) ||
2046 !isParallelIterator(iteratorTypes[1]) ||
2047 !isReductionIterator(iteratorTypes[2]))
2048 return failure();
2049
2050 Type opResType = op.getType();
2051 VectorType vecType = dyn_cast<VectorType>(opResType);
2052 if (vecType && vecType.isScalable()) {
2053 // Note - this is sufficient to reject all cases with scalable vectors.
2054 return failure();
2055 }
2056
2057 Type elementType = op.getLhsType().getElementType();
2058 if (!elementType.isIntOrFloat())
2059 return failure();
2060
2061 Type dstElementType = vecType ? vecType.getElementType() : opResType;
2062 if (elementType != dstElementType)
2063 return failure();
2064
2065 // Perform lhs + rhs transpositions to conform to matmul row-major semantics.
2066 // Bail out if the contraction cannot be put in this form.
2067 MLIRContext *ctx = op.getContext();
2068 Location loc = op.getLoc();
2069 AffineExpr m, n, k;
2070 bindDims(rew.getContext(), m, n, k);
2071 // LHS must be A(m, k) or A(k, m).
2072 Value lhs = op.getLhs();
2073 auto lhsMap = op.getIndexingMapsArray()[0];
2074 if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx))
2075 lhs = vector::TransposeOp::create(rew, loc, lhs, ArrayRef<int64_t>{1, 0});
2076 else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx))
2077 return failure();
2078
2079 // RHS must be B(k, n) or B(n, k).
2080 Value rhs = op.getRhs();
2081 auto rhsMap = op.getIndexingMapsArray()[1];
2082 if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx))
2083 rhs = vector::TransposeOp::create(rew, loc, rhs, ArrayRef<int64_t>{1, 0});
2084 else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx))
2085 return failure();
2086
2087 // At this point lhs and rhs are in row-major.
2088 VectorType lhsType = cast<VectorType>(lhs.getType());
2089 VectorType rhsType = cast<VectorType>(rhs.getType());
2090 int64_t lhsRows = lhsType.getDimSize(0);
2091 int64_t lhsColumns = lhsType.getDimSize(1);
2092 int64_t rhsColumns = rhsType.getDimSize(1);
2093
2094 Type flattenedLHSType =
2095 VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
2096 lhs = vector::ShapeCastOp::create(rew, loc, flattenedLHSType, lhs);
2097
2098 Type flattenedRHSType =
2099 VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
2100 rhs = vector::ShapeCastOp::create(rew, loc, flattenedRHSType, rhs);
2101
2102 Value mul = LLVM::MatrixMultiplyOp::create(
2103 rew, loc,
2104 VectorType::get(lhsRows * rhsColumns,
2105 cast<VectorType>(lhs.getType()).getElementType()),
2106 lhs, rhs, lhsRows, lhsColumns, rhsColumns);
2107
2108 mul = vector::ShapeCastOp::create(
2109 rew, loc,
2110 VectorType::get({lhsRows, rhsColumns},
2111 getElementTypeOrSelf(op.getAcc().getType())),
2112 mul);
2113
2114 // ACC must be C(m, n) or C(n, m).
2115 auto accMap = op.getIndexingMapsArray()[2];
2116 if (accMap == AffineMap::get(3, 0, {n, m}, ctx))
2117 mul = vector::TransposeOp::create(rew, loc, mul, ArrayRef<int64_t>{1, 0});
2118 else if (accMap != AffineMap::get(3, 0, {m, n}, ctx))
2119 llvm_unreachable("invalid contraction semantics");
2120
2121 Value res = isa<IntegerType>(elementType)
2122 ? static_cast<Value>(
2123 arith::AddIOp::create(rew, loc, op.getAcc(), mul))
2124 : static_cast<Value>(
2125 arith::AddFOp::create(rew, loc, op.getAcc(), mul));
2126
2127 return res;
2128}
2129
2130/// Lowers vector.transpose directly to llvm.intr.matrix.transpose
2131///
2132/// BEFORE:
2133/// ```mlir
2134/// %tr = vector.transpose %vec, [1, 0] : vector<2x4xf32> to vector<4x2xf32>
2135/// ```
2136/// AFTER:
2137/// ```mlir
2138/// %vec_cs = vector.shape_cast %vec : vector<2x4xf32> to vector<8xf32>
2139/// %tr = llvm.intr.matrix.transpose %vec_cs
2140/// {columns = 2 : i32, rows = 4 : i32} : vector<8xf32> into vector<8xf32>
2141/// %res = vector.shape_cast %tr : vector<8xf32> to vector<4x2xf32>
2142/// ```
2143class TransposeOpToMatrixTransposeOpLowering
2144 : public OpRewritePattern<vector::TransposeOp> {
2145public:
2146 using Base::Base;
2147
2148 LogicalResult matchAndRewrite(vector::TransposeOp op,
2149 PatternRewriter &rewriter) const override {
2150 auto loc = op.getLoc();
2151
2152 Value input = op.getVector();
2153 VectorType inputType = op.getSourceVectorType();
2154 VectorType resType = op.getResultVectorType();
2155
2156 if (inputType.isScalable())
2157 return rewriter.notifyMatchFailure(
2158 op, "This lowering does not support scalable vectors");
2159
2160 // Set up convenience transposition table.
2161 ArrayRef<int64_t> transp = op.getPermutation();
2162
2163 if (resType.getRank() != 2 || transp[0] != 1 || transp[1] != 0) {
2164 return failure();
2165 }
2166
2167 Type flattenedType =
2168 VectorType::get(resType.getNumElements(), resType.getElementType());
2169 auto matrix =
2170 vector::ShapeCastOp::create(rewriter, loc, flattenedType, input);
2171 auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
2172 auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
2173 Value trans = LLVM::MatrixTransposeOp::create(rewriter, loc, flattenedType,
2174 matrix, rows, columns);
2175 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
2176 return success();
2177 }
2178};
2179
2180} // namespace
2181
2182void mlir::vector::populateVectorRankReducingFMAPattern(
2183 RewritePatternSet &patterns) {
2184 patterns.add<VectorFMAOpNDRewritePattern>(patterns.getContext());
2185}
2186
2187void mlir::vector::populateVectorContractToMatrixMultiply(
2188 RewritePatternSet &patterns, PatternBenefit benefit) {
2189 patterns.add<ContractionOpToMatmulOpLowering>(patterns.getContext(), benefit);
2190}
2191
2192void mlir::vector::populateVectorTransposeToFlatTranspose(
2193 RewritePatternSet &patterns, PatternBenefit benefit) {
2194 patterns.add<TransposeOpToMatrixTransposeOpLowering>(patterns.getContext(),
2195 benefit);
2196}
2197
2198/// Populate the given list with patterns that convert from Vector to LLVM.
2199void mlir::populateVectorToLLVMConversionPatterns(
2200 const LLVMTypeConverter &converter, RewritePatternSet &patterns,
2201 bool reassociateFPReductions, bool force32BitVectorIndices,
2202 bool useVectorAlignment) {
2203 // This function populates only ConversionPatterns, not RewritePatterns.
2204 MLIRContext *ctx = converter.getDialect()->getContext();
2205 patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
2206 patterns.add<VectorCreateMaskOpConversion>(ctx, force32BitVectorIndices);
2207 patterns.add<VectorLoadStoreConversion<vector::LoadOp>,
2208 VectorLoadStoreConversion<vector::MaskedLoadOp>,
2209 VectorLoadStoreConversion<vector::StoreOp>,
2210 VectorLoadStoreConversion<vector::MaskedStoreOp>,
2211 VectorGatherOpConversion, VectorScatterOpConversion>(
2212 converter, useVectorAlignment);
2213 patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
2214 VectorExtractOpConversion, VectorFMAOp1DConversion,
2215 VectorInsertOpConversion, VectorPrintOpConversion,
2216 VectorTypeCastOpConversion, VectorScaleOpConversion,
2217 VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
2218 VectorBroadcastScalarToLowRankLowering,
2219 VectorBroadcastScalarToNdLowering,
2220 VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
2221 MaskedReductionOpConversion, VectorInterleaveOpLowering,
2222 VectorDeinterleaveOpLowering, VectorFromElementsLowering,
2223 VectorToElementsLowering, VectorScalableStepOpLowering>(
2224 converter);
2225}
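// Typical usage (illustrative sketch, not part of this file): a conversion
// pass built on the standard LLVM lowering infrastructure would wire these
// patterns up roughly as follows, where `op` stands for the pass's root
// operation:
//
//   LLVMTypeConverter converter(op->getContext());
//   RewritePatternSet patterns(op->getContext());
//   populateVectorToLLVMConversionPatterns(converter, patterns);
//   LLVMConversionTarget target(*op->getContext());
//   if (failed(applyPartialConversion(op, target, std::move(patterns))))
//     signalPassFailure();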
2226
2227namespace {
2228struct VectorToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
2229 using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
2230 void loadDependentDialects(MLIRContext *context) const final {
2231 context->loadDialect<LLVM::LLVMDialect>();
2232 }
2233
2234 /// Hook for derived dialect interface to provide conversion patterns
2235 /// and mark dialect legal for the conversion target.
2236 void populateConvertToLLVMConversionPatterns(
2237 ConversionTarget &target, LLVMTypeConverter &typeConverter,
2238 RewritePatternSet &patterns) const final {
2239 populateVectorToLLVMConversionPatterns(typeConverter, patterns);
2240 }
2241};
2242} // namespace
2243
2244void mlir::vector::registerConvertVectorToLLVMInterface(
2245 DialectRegistry &registry) {
2246 registry.addExtension(+[](MLIRContext *ctx, vector::VectorDialect *dialect) {
2247 dialect->addInterfaces<VectorToLLVMDialectInterface>();
2248 });
2249}