MLIR  22.0.0git
ConvertVectorToLLVMPass.cpp
Go to the documentation of this file.
1 //===- ConvertVectorToLLVMPass.cpp - Vector to LLVM dialect conversion ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
27 #include "mlir/Pass/Pass.h"
29 
30 namespace mlir {
31 #define GEN_PASS_DEF_CONVERTVECTORTOLLVMPASS
32 #include "mlir/Conversion/Passes.h.inc"
33 } // namespace mlir
34 
35 using namespace mlir;
36 using namespace mlir::vector;
37 
38 namespace {
39 struct ConvertVectorToLLVMPass
40  : public impl::ConvertVectorToLLVMPassBase<ConvertVectorToLLVMPass> {
41 
42  using Base::Base;
43 
44  // Override explicitly to allow conditional dialect dependence.
45  void getDependentDialects(DialectRegistry &registry) const override {
46  registry.insert<LLVM::LLVMDialect>();
47  registry.insert<arith::ArithDialect>();
48  registry.insert<memref::MemRefDialect>();
49  registry.insert<tensor::TensorDialect>();
50  if (armNeon)
51  registry.insert<arm_neon::ArmNeonDialect>();
52  if (armSVE)
53  registry.insert<arm_sve::ArmSVEDialect>();
54  if (amx)
55  registry.insert<amx::AMXDialect>();
56  if (x86Vector)
57  registry.insert<x86vector::X86VectorDialect>();
58  }
59  void runOnOperation() override;
60 };
61 } // namespace
62 
63 void ConvertVectorToLLVMPass::runOnOperation() {
64  // Perform progressive lowering of operations on slices and all contraction
65  // operations. Also materializes masks, lowers vector.step, rank-reduces FMA,
66  // applies folding and DCE.
67  {
72  populateVectorContractLoweringPatterns(patterns, vectorContractLowering);
73  if (vectorContractLowering == vector::VectorContractLowering::LLVMIntr) {
74  // This pattern creates a dependency on the LLVM dialect, hence we don't
75  // include it in `populateVectorContractLoweringPatterns` that is part of
76  // the Vector dialect (and should not depend on LLVM).
78  }
82  populateVectorTransposeLoweringPatterns(patterns, vectorTransposeLowering);
83  if (vectorTransposeLowering == vector::VectorTransposeLowering::LLVMIntr) {
84  // This pattern creates a dependency on the LLVM dialect, hence we don't
85  // include it in `populateVectorTransposeLoweringPatterns` that is part of
86  // the Vector dialect (and should not depend on LLVM).
88  }
89  // Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
90  populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
91  populateVectorMaskMaterializationPatterns(patterns,
92  force32BitVectorIndices);
93  populateVectorInsertExtractStridedSliceTransforms(patterns);
97  populateVectorFromElementsUnrollPatterns(patterns);
98  populateVectorToElementsUnrollPatterns(patterns);
99  if (armI8MM) {
100  if (armNeon)
102  if (armSVE)
104  }
105  if (armBF16) {
106  if (armNeon)
108  if (armSVE)
110  }
111  (void)applyPatternsGreedily(getOperation(), std::move(patterns));
112  }
113 
114  // Convert to the LLVM IR dialect.
116  LLVMTypeConverter converter(&getContext(), options);
120  converter, patterns, reassociateFPReductions, force32BitVectorIndices,
121  useVectorAlignment);
122 
123  // Architecture specific augmentations.
125  target.addLegalDialect<arith::ArithDialect>();
126  target.addLegalDialect<memref::MemRefDialect>();
127  target.addLegalOp<UnrealizedConversionCastOp>();
128 
129  if (armNeon) {
130  // TODO: we may or may not want to include in-dialect lowering to
131  // LLVM-compatible operations here. So far, all operations in the dialect
132  // can be translated to LLVM IR so there is no conversion necessary.
133  target.addLegalDialect<arm_neon::ArmNeonDialect>();
134  }
135  if (armSVE) {
138  }
139  if (amx) {
142  }
143  if (x86Vector) {
146  }
147 
148  if (failed(
149  applyPartialConversion(getOperation(), target, std::move(patterns))))
150  signalPassFailure();
151 }
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
Options to control the LLVM lowering.
void populateLowerContractionToNeonI8MMPatterns(RewritePatternSet &patterns)
void populateLowerContractionToNeonBFMMLAPatterns(RewritePatternSet &patterns)
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:561
void populateVectorContractToMatrixMultiply(RewritePatternSet &patterns, PatternBenefit benefit=100)
Populate the pattern set with the following patterns:
void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns)
Populates a pattern that rank-reduces n-D FMAs into (n-1)-D FMAs where n > 1.
void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorStepLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorTransferLoweringPatterns(RewritePatternSet &patterns, std::optional< unsigned > maxTransferRank=std::nullopt, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorTransposeLoweringPatterns(RewritePatternSet &patterns, VectorTransposeLowering vectorTransposeLowering, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorBroadcastLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorBitCastLoweringPatterns(RewritePatternSet &patterns, int64_t targetRank=1, PatternBenefit benefit=1)
Populates the pattern set with the following patterns:
void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of vector-to-vector canonicalization patterns.
void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorTransposeToFlatTranspose(RewritePatternSet &patterns, PatternBenefit benefit=100)
Populate the pattern set with the following patterns:
void populateVectorInterleaveLoweringPatterns(RewritePatternSet &patterns, int64_t targetRank=1, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorContractLoweringPatterns(RewritePatternSet &patterns, VectorContractLowering vectorContractLoweringOption, PatternBenefit benefit=1, bool disableOuterProductLowering=false)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
void configureArmSVELegalizeForExportTarget(LLVMConversionTarget &target)
Configure the target to support lowering ArmSVE ops to ops that map to LLVM intrinsics.
void populateX86VectorLegalizeForLLVMExportPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to lower X86Vector ops to ops that map to LLVM intrinsics.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
void populateLowerContractionToSVEBFMMLAPatterns(RewritePatternSet &patterns)
void populateAMXLegalizeForLLVMExportPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to lower AMX ops to ops that map to LLVM intrinsics.
void populateVectorToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, bool reassociateFPReductions=false, bool force32BitVectorIndices=false, bool useVectorAlignment=false)
Collect a set of patterns to convert from the Vector dialect to LLVM.
void populateLowerContractionToSVEI8MMPatterns(RewritePatternSet &patterns)
void populateArmSVELegalizeForLLVMExportPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to lower ArmSVE ops to ops that map to LLVM intrinsics.
void configureAMXLegalizeForExportTarget(LLVMConversionTarget &target)
Configure the target to support lowering AMX ops to ops that map to LLVM intrinsics.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
void configureX86VectorLegalizeForExportTarget(LLVMConversionTarget &target)
Configure the target to support lowering X86Vector ops to ops that map to LLVM intrinsics.