// MLIR 23.0.0git
// ConvertVectorToLLVMPass.cpp
// (Doxygen source listing -- "Go to the documentation of this file.")
//===- ConvertVectorToLLVMPass.cpp - Conversion from Vector to LLVM -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

25#include "mlir/Pass/Pass.h"
27
28namespace mlir {
29#define GEN_PASS_DEF_CONVERTVECTORTOLLVMPASS
30#include "mlir/Conversion/Passes.h.inc"
31} // namespace mlir
32
33using namespace mlir;
34using namespace mlir::vector;
35
36namespace {
37struct ConvertVectorToLLVMPass
38 : public impl::ConvertVectorToLLVMPassBase<ConvertVectorToLLVMPass> {
39
40 using Base::Base;
41
42 // Override explicitly to allow conditional dialect dependence.
43 void getDependentDialects(DialectRegistry &registry) const override {
44 registry.insert<LLVM::LLVMDialect>();
45 registry.insert<arith::ArithDialect>();
46 registry.insert<memref::MemRefDialect>();
47 registry.insert<tensor::TensorDialect>();
48 if (armNeon)
49 registry.insert<arm_neon::ArmNeonDialect>();
50 if (armSVE)
51 registry.insert<arm_sve::ArmSVEDialect>();
52 if (x86)
53 registry.insert<x86::X86Dialect>();
54 }
55 void runOnOperation() override;
56};
57} // namespace
58
59void ConvertVectorToLLVMPass::runOnOperation() {
60 // Perform progressive lowering of operations on slices and all contraction
61 // operations. Also materializes masks, lowers vector.step, rank-reduces FMA,
62 // applies folding and DCE.
63 {
64 RewritePatternSet patterns(&getContext());
68 populateVectorContractLoweringPatterns(patterns, vectorContractLowering);
69 if (vectorContractLowering == vector::VectorContractLowering::LLVMIntr) {
70 // This pattern creates a dependency on the LLVM dialect, hence we don't
71 // include it in `populateVectorContractLoweringPatterns` that is part of
72 // the Vector dialect (and should not depend on LLVM).
74 }
78 populateVectorTransposeLoweringPatterns(patterns, vectorTransposeLowering);
79 if (vectorTransposeLowering == vector::VectorTransposeLowering::LLVMIntr) {
80 // This pattern creates a dependency on the LLVM dialect, hence we don't
81 // include it in `populateVectorTransposeLoweringPatterns` that is part of
82 // the Vector dialect (and should not depend on LLVM).
84 }
85 // Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
86 populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
87 populateVectorMaskMaterializationPatterns(patterns,
88 force32BitVectorIndices);
89 populateVectorInsertExtractStridedSliceTransforms(patterns);
93 populateVectorFromElementsUnrollPatterns(patterns);
94 populateVectorToElementsUnrollPatterns(patterns);
95 if (armI8MM) {
96 if (armNeon)
98 if (armSVE)
100 }
101 if (armBF16) {
102 if (armNeon)
104 if (armSVE)
106 }
107 (void)applyPatternsGreedily(getOperation(), std::move(patterns));
108 }
109
110 // Convert to the LLVM IR dialect.
111 LowerToLLVMOptions options(&getContext());
112 LLVMTypeConverter converter(&getContext(), options);
113 RewritePatternSet patterns(&getContext());
116 converter, patterns, reassociateFPReductions, force32BitVectorIndices,
117 useVectorAlignment);
118
119 // Architecture specific augmentations.
120 LLVMConversionTarget target(getContext());
121 target.addLegalDialect<arith::ArithDialect>();
122 target.addLegalDialect<memref::MemRefDialect>();
123 target.addLegalOp<UnrealizedConversionCastOp>();
124
125 if (armNeon) {
126 // TODO: we may or may not want to include in-dialect lowering to
127 // LLVM-compatible operations here. So far, all operations in the dialect
128 // can be translated to LLVM IR so there is no conversion necessary.
129 target.addLegalDialect<arm_neon::ArmNeonDialect>();
130 }
131 if (armSVE) {
134 }
135 if (x86) {
137 populateX86LegalizeForLLVMExportPatterns(converter, patterns);
138 }
139
140 if (failed(
141 applyPartialConversion(getOperation(), target, std::move(patterns))))
142 signalPassFailure();
143}
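// --- Doxygen cross-reference tooltips captured with this listing (not code) ---
// [garbled tooltip fragment] b getContext())
// static llvm::ManagedStatic< PassManagerOptions > options
//   NOTE(review): appears mis-resolved by Doxygen; the `options` used in this
//   file is a local `LowerToLLVMOptions`, not this static.
// void populateLowerContractionToNeonI8MMPatterns(RewritePatternSet &patterns)
// void populateLowerContractionToNeonBFMMLAPatterns(RewritePatternSet &patterns)
// detail::InFlightRemark failed(Location loc, RemarkOpts opts)
//   Report an optimization remark that failed.
//   Definition Remarks.h:717
//   NOTE(review): appears mis-resolved by Doxygen; the single-argument
//   `failed(...)` call in this file does not match this two-parameter
//   signature.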
// void populateVectorContractToMatrixMultiply(RewritePatternSet &patterns, PatternBenefit benefit=100)
//   Populate the pattern set with the following patterns: [list truncated in tooltip]
// void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns)
//   Populates a pattern that rank-reduces n-D FMAs into (n-1)-D FMAs where n > 1.
// void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
//   Populate the pattern set with the following patterns: [list truncated in tooltip]
// void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
//   Populate the pattern set with the following patterns: [list truncated in tooltip]
// void populateVectorStepLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
//   Populate the pattern set with the following patterns: [list truncated in tooltip]
// void populateVectorTransferLoweringPatterns(RewritePatternSet &patterns, std::optional< unsigned > maxTransferRank=std::nullopt, PatternBenefit benefit=1)
//   Populate the pattern set with the following patterns: [list truncated in tooltip]
// void populateVectorTransposeLoweringPatterns(RewritePatternSet &patterns, VectorTransposeLowering vectorTransposeLowering, PatternBenefit benefit=1)
//   Populate the pattern set with the following patterns: [list truncated in tooltip]
// void populateVectorBroadcastLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
//   Populate the pattern set with the following patterns: [list truncated in tooltip]
// void populateVectorBitCastLoweringPatterns(RewritePatternSet &patterns, int64_t targetRank=1, PatternBenefit benefit=1)
//   Populates the pattern set with the following patterns: [list truncated in tooltip]
// void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
//   Collect a set of vector-to-vector canonicalization patterns.
// void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
//   Populate the pattern set with the following patterns: [list truncated in tooltip]
// void populateVectorTransposeToFlatTranspose(RewritePatternSet &patterns, PatternBenefit benefit=100)
//   Populate the pattern set with the following patterns: [list truncated in tooltip]
// void populateVectorInterleaveLoweringPatterns(RewritePatternSet &patterns, int64_t targetRank=1, PatternBenefit benefit=1)
//   Populate the pattern set with the following patterns: [list truncated in tooltip]
// void populateVectorContractLoweringPatterns(RewritePatternSet &patterns, VectorContractLowering vectorContractLoweringOption, PatternBenefit benefit=1, bool disableOuterProductLowering=false)
//   Populate the pattern set with the following patterns: [list truncated in tooltip]
// [stray tooltip brief with no associated declaration] Include the generated interface declarations.
// void configureArmSVELegalizeForExportTarget(LLVMConversionTarget &target)
//   Configure the target to support lowering ArmSVE ops to ops that map to LLVM intrinsics.
// LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
//   Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes... [truncated in tooltip]
// void configureX86LegalizeForExportTarget(LLVMConversionTarget &target)
//   Configure the target to support lowering X86 ops to ops that map to LLVM intrinsics.
// void populateLowerContractionToSVEBFMMLAPatterns(RewritePatternSet &patterns)
// void populateVectorToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, bool reassociateFPReductions=false, bool force32BitVectorIndices=false, bool useVectorAlignment=false)
//   Collect a set of patterns to convert from the Vector dialect to LLVM.
// void populateX86LegalizeForLLVMExportPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
//   Collect a set of patterns to lower X86 ops to ops that map to LLVM intrinsics.
// void populateLowerContractionToSVEI8MMPatterns(RewritePatternSet &patterns)
// void populateArmSVELegalizeForLLVMExportPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
//   Collect a set of patterns to lower ArmSVE ops to ops that map to LLVM intrinsics.