MLIR  18.0.0git
LegalizeForLLVMExport.cpp
Go to the documentation of this file.
1 //===- LegalizeForLLVMExport.cpp - Prepare ArmSVE for LLVM translation ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/PatternMatch.h"
19 
20 using namespace mlir;
21 using namespace mlir::arm_sve;
22 
23 template <typename OpTy>
24 class ForwardOperands : public OpConversionPattern<OpTy> {
26 
28  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
29  ConversionPatternRewriter &rewriter) const final {
30  if (adaptor.getOperands().getTypes() == op->getOperands().getTypes())
31  return rewriter.notifyMatchFailure(op, "operand types already match");
32 
33  rewriter.updateRootInPlace(
34  op, [&]() { op->setOperands(adaptor.getOperands()); });
35  return success();
36  }
37 };
38 
44  OneToOneConvertToLLVMPattern<ScalableMaskedAddIOp,
45  ScalableMaskedAddIIntrOp>;
47  OneToOneConvertToLLVMPattern<ScalableMaskedAddFOp,
48  ScalableMaskedAddFIntrOp>;
50  OneToOneConvertToLLVMPattern<ScalableMaskedSubIOp,
51  ScalableMaskedSubIIntrOp>;
53  OneToOneConvertToLLVMPattern<ScalableMaskedSubFOp,
54  ScalableMaskedSubFIntrOp>;
56  OneToOneConvertToLLVMPattern<ScalableMaskedMulIOp,
57  ScalableMaskedMulIIntrOp>;
59  OneToOneConvertToLLVMPattern<ScalableMaskedMulFOp,
60  ScalableMaskedMulFIntrOp>;
62  OneToOneConvertToLLVMPattern<ScalableMaskedSDivIOp,
63  ScalableMaskedSDivIIntrOp>;
65  OneToOneConvertToLLVMPattern<ScalableMaskedUDivIOp,
66  ScalableMaskedUDivIIntrOp>;
68  OneToOneConvertToLLVMPattern<ScalableMaskedDivFOp,
69  ScalableMaskedDivFIntrOp>;
70 
71 namespace {
72 
73 /// Unrolls a conversion to/from equivalent vector types, to allow using a
74 /// conversion intrinsic that only supports 1-D vector types.
75 ///
76 /// Example:
77 /// ```
78 /// %result = arm_sve.convert_to_svbool %source : vector<2x[4]xi1>
79 /// ```
80 /// is rewritten into:
81 /// ```
82 /// %cst = arith.constant dense<false> : vector<2x[16]xi1>
83 /// %1 = vector.extract %source[0] : vector<[4]xi1> from vector<2x[4]xi1>
84 /// %2 = "arm_sve.intr.convert.to.svbool"(%1)
85 /// : (vector<[4]xi1>) -> vector<[16]xi1>
86 /// %3 = vector.insert %2, %cst[0] : vector<[16]xi1> into vector<2x[16]xi1>
87 /// %4 = vector.extract %source[1] : vector<[4]xi1> from vector<2x[4]xi1>
88 /// %5 = "arm_sve.intr.convert.to.svbool"(%4)
89 /// : (vector<[4]xi1>) -> vector<[16]xi1>
90 /// %result = vector.insert %5, %3[1] : vector<[16]xi1> into vector<2x[16]xi1>
91 /// ```
92 template <typename Op, typename IntrOp>
93 struct SvboolConversionOpLowering : public ConvertOpToLLVMPattern<Op> {
95 
97  matchAndRewrite(Op convertOp, typename Op::Adaptor,
98  ConversionPatternRewriter &rewriter) const override {
99  auto loc = convertOp.getLoc();
100 
101  auto source = convertOp.getSource();
102  VectorType sourceType = source.getType();
103  VectorType resultType = convertOp.getResult().getType();
104 
105  Value result = rewriter.create<arith::ConstantOp>(
106  loc, resultType, rewriter.getZeroAttr(resultType));
107 
108  // We want to iterate over the input vector in steps of the trailing
109  // dimension. So this creates tile shape where all leading dimensions are 1,
110  // and the trailing dimension step is the size of the dimension.
111  SmallVector<int64_t> tileShape(sourceType.getRank(), 1);
112  tileShape.back() = sourceType.getShape().back();
113 
114  // Iterate over all scalable mask/predicate slices of the source vector.
115  for (SmallVector<int64_t> index :
116  StaticTileOffsetRange(sourceType.getShape(), tileShape)) {
117  auto extractOrInsertPosition = ArrayRef(index).drop_back();
118  auto sourceVector = rewriter.create<vector::ExtractOp>(
119  loc, source, extractOrInsertPosition);
120  VectorType convertedType =
121  VectorType::Builder(llvm::cast<VectorType>(sourceVector.getType()))
122  .setDim(0, resultType.getShape().back());
123  auto convertedVector =
124  rewriter.create<IntrOp>(loc, TypeRange{convertedType}, sourceVector);
125  result = rewriter.create<vector::InsertOp>(loc, convertedVector, result,
126  extractOrInsertPosition);
127  }
128 
129  rewriter.replaceOp(convertOp, result);
130  return success();
131  }
132 };
133 
134 using ConvertToSvboolOpLowering =
135  SvboolConversionOpLowering<ConvertToSvboolOp, ConvertToSvboolIntrOp>;
136 
137 using ConvertFromSvboolOpLowering =
138  SvboolConversionOpLowering<ConvertFromSvboolOp, ConvertFromSvboolIntrOp>;
139 
140 } // namespace
141 
142 /// Populate the given list with patterns that convert from ArmSVE to LLVM.
///
/// \p converter is handed to every pattern registered here; \p patterns is
/// the set the new patterns are appended to.
///
/// NOTE(review): this listing is a documentation-page rendering that omits
/// several source lines: the function declarator (line 143; per the
/// declaration the signature is `void populateArmSVELegalizeForLLVMExportPatterns(
/// LLVMTypeConverter &, RewritePatternSet &)`, presumably defined with an
/// `mlir::` qualifier — confirm), lines 148-150 (whatever precedes
/// `&converter.getContext());` — presumably a `patterns.add` of the
/// `ForwardOperands<...>` patterns, the only OpConversionPattern-based class
/// visible here — confirm), and lines 153-164 of the second pattern list.
/// Restore them from the upstream file before building.
144  LLVMTypeConverter &converter, RewritePatternSet &patterns) {
145  // Populate conversion patterns
146 
147  // clang-format off
151  &converter.getContext());
// One-to-one ArmSVE-op -> intrinsic lowerings plus the svbool conversion
// patterns, all constructed from the type converter alone.
152  patterns.add<SdotOpLowering,
165  ConvertToSvboolOpLowering,
166  ConvertFromSvboolOpLowering>(converter);
167  // clang-format on
168 }
169 
171  LLVMConversionTarget &target) {
172  // clang-format off
173  target.addLegalOp<SdotIntrOp,
174  SmmlaIntrOp,
175  UdotIntrOp,
176  UmmlaIntrOp,
177  ScalableMaskedAddIIntrOp,
178  ScalableMaskedAddFIntrOp,
179  ScalableMaskedSubIIntrOp,
180  ScalableMaskedSubFIntrOp,
181  ScalableMaskedMulIIntrOp,
182  ScalableMaskedMulFIntrOp,
183  ScalableMaskedSDivIIntrOp,
184  ScalableMaskedUDivIIntrOp,
185  ScalableMaskedDivFIntrOp,
186  ConvertToSvboolIntrOp,
187  ConvertFromSvboolIntrOp>();
188  target.addIllegalOp<SdotOp,
189  SmmlaOp,
190  UdotOp,
191  UmmlaOp,
192  ScalableMaskedAddIOp,
193  ScalableMaskedAddFOp,
194  ScalableMaskedSubIOp,
195  ScalableMaskedSubFOp,
196  ScalableMaskedMulIOp,
197  ScalableMaskedMulFOp,
198  ScalableMaskedSDivIOp,
199  ScalableMaskedUDivIOp,
200  ScalableMaskedDivFOp,
201  ConvertToSvboolOp,
202  ConvertFromSvboolOp>();
203  // clang-format on
204 }
OneToOneConvertToLLVMPattern< UmmlaOp, UmmlaIntrOp > UmmlaOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedSubIOp, ScalableMaskedSubIIntrOp > ScalableMaskedSubIOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedSDivIOp, ScalableMaskedSDivIIntrOp > ScalableMaskedSDivIOpLowering
OneToOneConvertToLLVMPattern< SdotOp, SdotIntrOp > SdotOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedAddIOp, ScalableMaskedAddIIntrOp > ScalableMaskedAddIOpLowering
OneToOneConvertToLLVMPattern< SmmlaOp, SmmlaIntrOp > SmmlaOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedSubFOp, ScalableMaskedSubFIntrOp > ScalableMaskedSubFOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedUDivIOp, ScalableMaskedUDivIIntrOp > ScalableMaskedUDivIOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedDivFOp, ScalableMaskedDivFIntrOp > ScalableMaskedDivFOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedMulFOp, ScalableMaskedMulFIntrOp > ScalableMaskedMulFOpLowering
OneToOneConvertToLLVMPattern< UdotOp, UdotIntrOp > UdotOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedAddFOp, ScalableMaskedAddFIntrOp > ScalableMaskedAddFOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedMulIOp, ScalableMaskedMulIIntrOp > ScalableMaskedMulIOpLowering
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:331
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
PatternRewriter hook for notifying match failure reasons.
void addLegalOp(OperationName op)
Register the given operations as legal.
void addIllegalOp(OperationName op)
Register the given operation as illegal, i.e.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:139
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:33
MLIRContext & getContext() const
Returns the MLIR context.
Generic implementation of one-to-one conversion from "SourceOp" to "TargetOp" where the latter belong...
Definition: Pattern.h:194
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:446
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
Location getLoc()
The source location the operation was defined or derived from.
Definition: OpDefinition.h:125
This provides public APIs that all operations should have.
type_range getTypes() const
Definition: ValueRange.cpp:26
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
Definition: Operation.cpp:236
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
Definition: PatternMatch.h:606
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:305
Builder & setDim(unsigned pos, int64_t val)
Set a dim in shape @pos to val.
Definition: BuiltinTypes.h:339
Include the generated interface declarations.
void configureArmSVELegalizeForExportTarget(LLVMConversionTarget &target)
Configure the target to support lowering ArmSVE ops to ops that map to LLVM intrinsics.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
void populateArmSVELegalizeForLLVMExportPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to lower ArmSVE ops to ops that map to LLVM intrinsics.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26