MLIR  16.0.0git
ConvertVectorToLLVMPass.cpp
Go to the documentation of this file.
1 //===- ConvertVectorToLLVMPass.cpp - Vector to LLVM dialect pass ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
25 #include "mlir/Pass/Pass.h"
27 
28 namespace mlir {
29 #define GEN_PASS_DEF_CONVERTVECTORTOLLVM
30 #include "mlir/Conversion/Passes.h.inc"
31 } // namespace mlir
32 
33 using namespace mlir;
34 using namespace mlir::vector;
35 
36 namespace {
37 struct LowerVectorToLLVMPass
38  : public impl::ConvertVectorToLLVMBase<LowerVectorToLLVMPass> {
39  LowerVectorToLLVMPass(const LowerVectorToLLVMOptions &options) {
40  this->reassociateFPReductions = options.reassociateFPReductions;
41  this->force32BitVectorIndices = options.force32BitVectorIndices;
42  this->armNeon = options.armNeon;
43  this->armSVE = options.armSVE;
44  this->amx = options.amx;
45  this->x86Vector = options.x86Vector;
46  }
47  // Override explicitly to allow conditional dialect dependence.
48  void getDependentDialects(DialectRegistry &registry) const override {
49  registry.insert<LLVM::LLVMDialect>();
50  registry.insert<arith::ArithDialect>();
51  registry.insert<memref::MemRefDialect>();
52  if (armNeon)
53  registry.insert<arm_neon::ArmNeonDialect>();
54  if (armSVE)
55  registry.insert<arm_sve::ArmSVEDialect>();
56  if (amx)
57  registry.insert<amx::AMXDialect>();
58  if (x86Vector)
59  registry.insert<x86vector::X86VectorDialect>();
60  }
61  void runOnOperation() override;
62 };
63 } // namespace
64 
65 void LowerVectorToLLVMPass::runOnOperation() {
66  // Perform progressive lowering of operations on slices and
67  // all contraction operations. Also applies folding and DCE.
68  {
69  RewritePatternSet patterns(&getContext());
76  // Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
77  populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
78  (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
79  }
80 
81  // Convert to the LLVM IR dialect.
82  LLVMTypeConverter converter(&getContext());
83  RewritePatternSet patterns(&getContext());
84  populateVectorMaskMaterializationPatterns(patterns, force32BitVectorIndices);
88  converter, patterns, reassociateFPReductions, force32BitVectorIndices);
90 
91  // Architecture specific augmentations.
92  LLVMConversionTarget target(getContext());
93  target.addLegalDialect<arith::ArithDialect>();
94  target.addLegalDialect<memref::MemRefDialect>();
95  target.addLegalOp<UnrealizedConversionCastOp>();
96  if (armNeon) {
97  // TODO: we may or may not want to include in-dialect lowering to
98  // LLVM-compatible operations here. So far, all operations in the dialect
99  // can be translated to LLVM IR so there is no conversion necessary.
100  target.addLegalDialect<arm_neon::ArmNeonDialect>();
101  }
102  if (armSVE) {
104  populateArmSVELegalizeForLLVMExportPatterns(converter, patterns);
105  }
106  if (amx) {
108  populateAMXLegalizeForLLVMExportPatterns(converter, patterns);
109  }
110  if (x86Vector) {
113  }
114 
115  if (failed(
116  applyPartialConversion(getOperation(), target, std::move(patterns))))
117  signalPassFailure();
118 }
119 
120 std::unique_ptr<OperationPass<ModuleOp>>
122  return std::make_unique<LowerVectorToLLVMPass>(options);
123 }
Include the generated interface declarations.
void addLegalOp(OperationName op)
Register the given operations as legal.
LogicalResult applyPartialConversion(ArrayRef< Operation *> ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation *> *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collects patterns to progressively lower vector.shape_cast ops on high-D vectors into 1-D/2-D vector ...
Derived class that automatically populates legalization information for different LLVM ops...
void populateVectorToLLVMMatrixConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to convert from Vector contractions to LLVM Matrix Intrinsics.
std::unique_ptr< OperationPass< ModuleOp > > createConvertVectorToLLVMPass(const LowerVectorToLLVMOptions &options=LowerVectorToLLVMOptions())
Create a pass to convert vector operations to the LLVMIR dialect.
void populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, bool reassociateFPReductions=false, bool force32BitVectorIndices=false)
Collect a set of patterns to convert from the Vector dialect to LLVM.
void configureX86VectorLegalizeForExportTarget(LLVMConversionTarget &target)
Configure the target to support lowering X86Vector ops to ops that map to LLVM intrinsics.
void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of vector-to-vector canonicalization patterns.
void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns, bool force32BitVectorIndices, PatternBenefit benefit=1)
These patterns materialize masks for various vector ops such as transfers.
void populateVectorContractLoweringPatterns(RewritePatternSet &patterns, VectorTransformsOptions options=VectorTransformsOptions(), PatternBenefit benefit=1)
Collects patterns to progressively lower vector contraction ops on high-D into low-D reduction and pr...
void populateX86VectorLegalizeForLLVMExportPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to lower X86Vector ops to ops that map to LLVM intrinsics.
Options to control Vector to LLVM lowering.
void addLegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as legal.
void populateVectorTransposeLoweringPatterns(RewritePatternSet &patterns, VectorTransformsOptions options=VectorTransformsOptions(), PatternBenefit benefit=1)
Insert TransposeLowering patterns into extraction/insertion.
static llvm::ManagedStatic< PassManagerOptions > options
void populateVectorBroadcastLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collects patterns to progressively lower vector.broadcast ops on high-D vectors to low-D vector ops...
void populateArmSVELegalizeForLLVMExportPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to lower ArmSVE ops to ops that map to LLVM intrinsics. ...
void populateAMXLegalizeForLLVMExportPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to lower AMX ops to ops that map to LLVM intrinsics.
void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collects patterns to progressively lower vector mask ops into elementary selection and insertion ops...
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:30
void configureArmSVELegalizeForExportTarget(LLVMConversionTarget &target)
Configure the target to support lowering ArmSVE ops to ops that map to LLVM intrinsics.
void configureAMXLegalizeForExportTarget(LLVMConversionTarget &target)
Configure the target to support lowering AMX ops to ops that map to LLVM intrinsics.
LogicalResult applyPatternsAndFoldGreedily(MutableArrayRef< Region > regions, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig())
Rewrite the regions of the specified operation, which must be isolated from above, by repeatedly applying the highest benefit patterns in a greedy work-list driven manner.
void populateVectorTransferLoweringPatterns(RewritePatternSet &patterns, llvm::Optional< unsigned > maxTransferRank=llvm::None, PatternBenefit benefit=1)
Collect a set of transfer read/write lowering patterns.