MLIR 23.0.0git
IndexedAccessOpInterfaceImpl.cpp
Go to the documentation of this file.
1//===- IndexedAccessOpInterfaceImpl.cpp -----------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8// Implement IndexedAccessOpInterface on vector dialect operations with
9// %memref[%i, %j, ...] operands so generic memref-dialect passes can rewrite
10// their base/index pairs. Transfer ops keep their VectorTransferOpInterface
11// patterns; gather/scatter have tensor-or-memref bases and index-vector
12// operands that do not fit IndexedAccessOpInterface's rank-matched index
13// contract.
14//===----------------------------------------------------------------------===//
15
17
20#include "mlir/IR/Dialect.h"
21#include "mlir/IR/Operation.h"
23
24using namespace mlir;
25using namespace mlir::memref;
26
27namespace {
28/// Return true if this op has the memref semantics expected by this model.
29template <typename LoadStoreOp>
30bool hasMemrefSemantics(Operation *op) {
31 return llvm::isa<MemRefType>(cast<LoadStoreOp>(op).getBase().getType());
32}
33
34/// Return the vector shape whose access strides must be preserved, marking
35/// scalable dimensions as dynamic.
36SmallVector<int64_t> getAccessedVectorShape(VectorType vecTy) {
37 return llvm::map_to_vector(
38 llvm::zip_equal(vecTy.getShape(), vecTy.getScalableDims()), [](auto dim) {
39 auto [size, scalable] = dim;
40 return scalable ? ShapedType::kDynamic : size;
41 });
42}
43
44template <typename LoadStoreOp>
45struct VectorLoadStoreLikeOpImpl final
46 : IndexedAccessOpInterface::ExternalModel<
47 VectorLoadStoreLikeOpImpl<LoadStoreOp>, LoadStoreOp> {
48 TypedValue<MemRefType> getAccessedMemref(Operation *op) const {
49 return cast<LoadStoreOp>(op).getBase();
50 }
51
52 Operation::operand_range getIndices(Operation *op) const {
53 return cast<LoadStoreOp>(op).getIndices();
54 }
55
56 SmallVector<int64_t> getAccessedShape(Operation *op) const {
57 assert(hasMemrefSemantics<LoadStoreOp>(op) &&
58 "expected vector op with memref semantics");
59 return getAccessedVectorShape(cast<LoadStoreOp>(op).getVectorType());
60 }
61
62 std::optional<SmallVector<Value>>
63 updateMemrefAndIndices(Operation *op, RewriterBase &rewriter, Value newMemref,
64 ValueRange newIndices) const {
65 assert(hasMemrefSemantics<LoadStoreOp>(op) &&
66 "expected vector op with memref semantics");
67 assert(llvm::isa<MemRefType>(newMemref.getType()) &&
68 "expected replacement memref");
69 rewriter.modifyOpInPlace(op, [&]() {
70 auto concreteOp = cast<LoadStoreOp>(op);
71 concreteOp.getBaseMutable().assign(newMemref);
72 concreteOp.getIndicesMutable().assign(newIndices);
73 });
74 return std::nullopt;
75 }
76
77 // TODO: The various load and store operations, at the very least vector.load
78 // and vector.store, should be taught a starts-in-bounds attribute that would
79 // let us optimize index generation.
80 bool hasInboundsIndices(Operation *op) const {
81 assert(hasMemrefSemantics<LoadStoreOp>(op) &&
82 "expected vector op with memref semantics");
83 return false;
84 }
85};
86
87template <typename... Ops>
88static void attachAll(MLIRContext *ctx) {
89 (Ops::template attachInterface<VectorLoadStoreLikeOpImpl<Ops>>(*ctx), ...);
90}
91
92} // namespace
93
95 DialectRegistry &registry) {
96 registry.addExtension(+[](MLIRContext *ctx, vector::VectorDialect *dialect) {
97 attachAll<vector::LoadOp, vector::StoreOp, vector::MaskedLoadOp,
98 vector::MaskedStoreOp, vector::ExpandLoadOp,
99 vector::CompressStoreOp>(ctx);
100 });
101}
static Value getBase(Value v)
Looks through known "view-like" ops to find the base memref.
static VectorType getVectorType(Type scalarTy, const VectorizationStrategy *strategy)
Returns the vector type resulting from applying the provided vectorization strategy on the scalar typ...
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
OperandRange operand_range
Definition Operation.h:397
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition Utils.cpp:18
void registerIndexedAccessOpInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:494