MLIR  19.0.0git
ArmNeon2dToIntr.cpp
Go to the documentation of this file.
1 //===- ArmNeon2dToIntr.cpp - convert Arm Neon 2d ops to intrinsics --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
#include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h"

#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
17 
18 namespace mlir {
19 #define GEN_PASS_DEF_CONVERTARMNEON2DTOINTR
20 #include "mlir/Conversion/Passes.h.inc"
21 } // namespace mlir
22 
23 using namespace mlir;
24 using namespace mlir::arm_neon;
25 
26 namespace {
27 
28 class Sdot2dLoweringPattern : public OpRewritePattern<Sdot2dOp> {
29 public:
31 
32  /// Convert to 1-dimensional vector type to match the requirements of
33  /// arm.neon.intr.sdot
34  LogicalResult matchAndRewrite(Sdot2dOp op,
35  PatternRewriter &rewriter) const override {
36  Type elemType = cast<VectorType>(op.getB().getType()).getElementType();
37  int length = cast<VectorType>(op.getB().getType()).getShape()[0] *
38  Sdot2dOp::kReductionSize;
39  VectorType flattenedVectorType = VectorType::get({length}, elemType);
40  Value b2d = op.getB();
41  Value c2d = op.getC();
42  Location loc = op.getLoc();
43  Value b1d =
44  rewriter.create<vector::ShapeCastOp>(loc, flattenedVectorType, b2d);
45  Value c1d =
46  rewriter.create<vector::ShapeCastOp>(loc, flattenedVectorType, c2d);
47  Value newOp = rewriter.create<SdotOp>(loc, op.getRes().getType(), op.getA(),
48  b1d, c1d);
49  rewriter.replaceOp(op, {newOp});
50  return success();
51  }
52 };
53 
54 class ConvertArmNeon2dToIntr
55  : public impl::ConvertArmNeon2dToIntrBase<ConvertArmNeon2dToIntr> {
56  void runOnOperation() override {
57  auto *context = &getContext();
58 
59  RewritePatternSet patterns(context);
61 
62  if (failed(
63  applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
64  return signalPassFailure();
65  }
66 };
67 
68 } // namespace
69 
71  patterns.add<Sdot2dLoweringPattern>(patterns.getContext());
72 }
73 
74 std::unique_ptr<Pass> mlir::createConvertArmNeon2dToIntrPass() {
75  return std::make_unique<ConvertArmNeon2dToIntr>();
76 }
static MLIRContext * getContext(OpFoldResult val)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:63
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
Definition: PatternMatch.h:785
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Include the generated interface declarations.
void populateConvertArmNeon2dToIntrPatterns(RewritePatternSet &patterns)
Populates patterns for the lowering of Arm NEON 2D ops to intrinsics.
std::unique_ptr< Pass > createConvertArmNeon2dToIntrPass()
Creates a pass to lower Arm NEON 2D ops to intrinsics, i.e.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362