MLIR 22.0.0git
ArmNeon2dToIntr.cpp
Go to the documentation of this file.
1//===- ArmNeon2dToIntr.cpp - convert Arm Neon 2d ops to intrinsics --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
14#include "mlir/Pass/Pass.h"
16
17namespace mlir {
18#define GEN_PASS_DEF_CONVERTARMNEON2DTOINTRPASS
19#include "mlir/Conversion/Passes.h.inc"
20} // namespace mlir
21
22using namespace mlir;
23using namespace mlir::arm_neon;
24
25namespace {
26
27class Sdot2dLoweringPattern : public OpRewritePattern<Sdot2dOp> {
28public:
30
31 /// Convert to 1-dimensional vector type to match the requirements of
32 /// arm.neon.intr.sdot
33 LogicalResult matchAndRewrite(Sdot2dOp op,
34 PatternRewriter &rewriter) const override {
35 Type elemType = cast<VectorType>(op.getB().getType()).getElementType();
36 int length = cast<VectorType>(op.getB().getType()).getShape()[0] *
37 Sdot2dOp::kReductionSize;
38 VectorType flattenedVectorType = VectorType::get({length}, elemType);
39 Value b2d = op.getB();
40 Value c2d = op.getC();
41 Location loc = op.getLoc();
42 Value b1d =
43 vector::ShapeCastOp::create(rewriter, loc, flattenedVectorType, b2d);
44 Value c1d =
45 vector::ShapeCastOp::create(rewriter, loc, flattenedVectorType, c2d);
46 Value newOp = SdotOp::create(rewriter, loc, op.getRes().getType(),
47 op.getA(), b1d, c1d);
48 rewriter.replaceOp(op, {newOp});
49 return success();
50 }
51};
52
53class ConvertArmNeon2dToIntr
54 : public impl::ConvertArmNeon2dToIntrPassBase<ConvertArmNeon2dToIntr> {
55 void runOnOperation() override {
56 auto *context = &getContext();
57
58 RewritePatternSet patterns(context);
60
61 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
62 return signalPassFailure();
63 }
64};
65
66} // namespace
67
69 patterns.add<Sdot2dLoweringPattern>(patterns.getContext());
70}
return success()
b getContext())
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest-benefit patterns in a greedy, worklist-driven manner until a fixpoint is reached.
void populateConvertArmNeon2dToIntrPatterns(RewritePatternSet &patterns)
Populates patterns for the lowering of Arm NEON 2D ops to intrinsics.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class, as opposed to a raw Operation.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern matching and a list of generated ops.