MLIR  19.0.0git
ConvertVectorToLLVMPass.cpp
Go to the documentation of this file.
4 //===- ConvertVectorToLLVMPass.cpp - Vector to LLVM dialect pass ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
26 #include "mlir/Pass/Pass.h"
28 
29 namespace mlir {
30 #define GEN_PASS_DEF_CONVERTVECTORTOLLVMPASS
31 #include "mlir/Conversion/Passes.h.inc"
32 } // namespace mlir
33 
34 using namespace mlir;
35 using namespace mlir::vector;
36 
37 namespace {
38 struct LowerVectorToLLVMPass
39  : public impl::ConvertVectorToLLVMPassBase<LowerVectorToLLVMPass> {
40 
41  using Base::Base;
42 
43  // Override explicitly to allow conditional dialect dependence.
44  void getDependentDialects(DialectRegistry &registry) const override {
45  registry.insert<LLVM::LLVMDialect>();
46  registry.insert<arith::ArithDialect>();
47  registry.insert<memref::MemRefDialect>();
48  if (armNeon)
49  registry.insert<arm_neon::ArmNeonDialect>();
50  if (armSVE)
51  registry.insert<arm_sve::ArmSVEDialect>();
52  if (amx)
53  registry.insert<amx::AMXDialect>();
54  if (x86Vector)
55  registry.insert<x86vector::X86VectorDialect>();
56  }
57  void runOnOperation() override;
58 };
59 } // namespace
60 
// Lowers the operation this pass runs on to the LLVM dialect in two phases:
//  1. A greedy rewrite (with folding) that progressively lowers vector ops
//     within the Vector dialect. Transfer-op lowering is capped at rank 1 so
//     that higher-rank transfers are left for VectorToSCF.
//  2. A partial conversion to the LLVM dialect. It uses an LLVMTypeConverter
//     and a target that keeps arith, memref and UnrealizedConversionCastOp
//     legal, plus whichever architecture-specific dialects the armNeon /
//     armSVE / amx / x86Vector options enable.
// Only a failure in phase 2 calls signalPassFailure(); phase 1's result is
// discarded.
// NOTE(review): this listing is a rendered Doxygen page. Source lines 66-73,
// 80, 84-86, 88, 91, 103, 107 and 111-112 are missing, including the
// declarations of `options` and `target` and the start of the call that
// continues on line 87. Do not edit the code from this view alone.
61 void LowerVectorToLLVMPass::runOnOperation() {
62  // Perform progressive lowering of operations on slices and
63  // all contraction operations. Also applies folding and DCE.
64  {
// Scoped so that this first `patterns` set is consumed and out of scope
// before the conversion-phase set with the same name is declared below.
65  RewritePatternSet patterns(&getContext());
// NOTE(review): source lines 66-73 (further pattern population) are missing
// from this listing.
74  // Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
75  populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
// The LogicalResult is deliberately discarded via the (void) cast, so a
// failure here does not fail the pass.
76  (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
77  }
78 
79  // Convert to the LLVM IR dialect.
// NOTE(review): `options` is declared on source line 80, which is missing
// here. Presumably it is the LLVM lowering options object; confirm.
81  LLVMTypeConverter converter(&getContext(), options);
82  RewritePatternSet patterns(&getContext());
83  populateVectorMaskMaterializationPatterns(patterns, force32BitVectorIndices);
// NOTE(review): line 87 is the tail of a call whose head (source lines 84-86)
// is missing. The arguments match populateVectorToLLVMConversionPatterns'
// signature; confirm.
87  converter, patterns, reassociateFPReductions, force32BitVectorIndices);
89 
90  // Architecture specific augmentations.
// NOTE(review): `target` is declared on source line 91, which is missing
// here. Presumably it is an LLVMConversionTarget; confirm.
92  target.addLegalDialect<arith::ArithDialect>();
93  target.addLegalDialect<memref::MemRefDialect>();
94  target.addLegalOp<UnrealizedConversionCastOp>();
95 
96  if (armNeon) {
97  // TODO: we may or may not want to include in-dialect lowering to
98  // LLVM-compatible operations here. So far, all operations in the dialect
99  // can be translated to LLVM IR so there is no conversion necessary.
100  target.addLegalDialect<arm_neon::ArmNeonDialect>();
101  }
// NOTE(review): source lines 103, 107 and 111-112 inside the next three
// branches are missing from this listing (presumably target configuration
// and, for x86Vector, pattern population; confirm).
102  if (armSVE) {
104  populateArmSVELegalizeForLLVMExportPatterns(converter, patterns);
105  }
106  if (amx) {
108  populateAMXLegalizeForLLVMExportPatterns(converter, patterns);
109  }
110  if (x86Vector) {
113  }
114 
// Any failure of the partial conversion fails the pass.
115  if (failed(
116  applyPartialConversion(getOperation(), target, std::move(patterns))))
117  signalPassFailure();
118 }
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:34
Options to control the LLVM lowering.
void populateVectorTransposeLoweringPatterns(RewritePatternSet &patterns, VectorTransformsOptions options, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorTransferLoweringPatterns(RewritePatternSet &patterns, std::optional< unsigned > maxTransferRank=std::nullopt, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorBroadcastLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of vector-to-vector canonicalization patterns.
void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns, bool force32BitVectorIndices, PatternBenefit benefit=1)
These patterns materialize masks for various vector ops such as transfers.
void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorContractLoweringPatterns(RewritePatternSet &patterns, VectorTransformsOptions options, PatternBenefit benefit=1, bool disableOuterProductLowering=false)
Populate the pattern set with the following patterns:
void populateVectorInterleaveLoweringPatterns(RewritePatternSet &patterns, int64_t targetRank=1, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
void configureArmSVELegalizeForExportTarget(LLVMConversionTarget &target)
Configure the target to support lowering ArmSVE ops to ops that map to LLVM intrinsics.
void populateArmSVELegalizeForLLVMExportPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to lower ArmSVE ops to ops that map to LLVM intrinsics.
void populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, bool reassociateFPReductions=false, bool force32BitVectorIndices=false)
Collect a set of patterns to convert from the Vector dialect to LLVM.
void populateVectorToLLVMMatrixConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to convert from Vector contractions to LLVM Matrix Intrinsics.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest benefit patterns in a greedy worklist driven manner until a fixpoint is reached.
void populateAMXLegalizeForLLVMExportPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to lower AMX ops to ops that map to LLVM intrinsics.
void populateX86VectorLegalizeForLLVMExportPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to lower X86Vector ops to ops that map to LLVM intrinsics.
void configureAMXLegalizeForExportTarget(LLVMConversionTarget &target)
Configure the target to support lowering AMX ops to ops that map to LLVM intrinsics.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
void configureX86VectorLegalizeForExportTarget(LLVMConversionTarget &target)
Configure the target to support lowering X86Vector ops to ops that map to LLVM intrinsics.
Structure to control the behavior of vector transform patterns.