MLIR 23.0.0git
IndexedAccessOpInterfaceImpl.cpp
Go to the documentation of this file.
1//===- IndexedAccessOpInterfaceImpl.cpp -----------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8// Implement IndexedAccessOpInterface on vector dialect operations with
9// %memref[%i, %j, ...] operands so generic memref-dialect passes can rewrite
10// their base/index pairs. Redundant leading unit vector dimensions are omitted
11// from the accessed shape and restored with vector.shape_casts when an alias
12// rewrite drops those dimensions. Transfer ops keep their
13// VectorTransferOpInterface patterns; gather/scatter have tensor-or-memref
14// bases and index-vector operands that do not fit IndexedAccessOpInterface's
15// rank-matched index contract.
16//===----------------------------------------------------------------------===//
17
19
22#include "mlir/IR/Dialect.h"
23#include "mlir/IR/IRMapping.h"
24#include "mlir/IR/Operation.h"
26
27#include <type_traits>
28
29using namespace mlir;
30using namespace mlir::memref;
31
32/// Return true if this op has the memref semantics expected by this model.
33template <typename LoadStoreOp>
34static bool hasMemrefSemantics(Operation *op) {
35 return llvm::isa<MemRefType>(cast<LoadStoreOp>(op).getBase().getType());
36}
37
38/// Return true if this op supports rank-0 vector operands/results.
39template <typename LoadStoreOp>
40static constexpr bool supportsRankZeroVectorAccess() {
41 return std::is_same_v<LoadStoreOp, vector::LoadOp> ||
42 std::is_same_v<LoadStoreOp, vector::StoreOp>;
43}
44
45/// Return the number of leading static unit dimensions in `vecTy`.
46static unsigned getNumLeadingUnitDims(VectorType vecTy) {
47 unsigned numLeadingUnitDims = 0;
48 for (auto [size, scalable] :
49 llvm::zip_equal(vecTy.getShape(), vecTy.getScalableDims())) {
50 if (size != 1 || scalable)
51 break;
52 ++numLeadingUnitDims;
53 }
54 return numLeadingUnitDims;
55}
56
57/// Return the vector shape whose access strides must be preserved, omitting
58/// redundant leading static unit dimensions and marking scalable dimensions as
59/// dynamic. If the op cannot access rank-0 vectors, preserve one trailing unit
60/// dimension instead of returning an empty shape.
62 bool supportsRankZero) {
63 unsigned numLeadingUnitDims = getNumLeadingUnitDims(vecTy);
64 unsigned rank = static_cast<unsigned>(vecTy.getRank());
65 if (!supportsRankZero && numLeadingUnitDims == rank)
66 --numLeadingUnitDims;
67 return llvm::map_to_vector(
68 llvm::zip_equal(vecTy.getShape().drop_front(numLeadingUnitDims),
69 vecTy.getScalableDims().drop_front(numLeadingUnitDims)),
70 [](auto dim) {
71 auto [size, scalable] = dim;
72 return scalable ? ShapedType::kDynamic : size;
73 });
74}
75
76/// Return `vecTy` with `numLeadingDims` dimensions dropped from the front.
77static VectorType dropLeadingDims(VectorType vecTy, unsigned numLeadingDims) {
78 return VectorType::get(vecTy.getShape().drop_front(numLeadingDims),
79 vecTy.getElementType(),
80 vecTy.getScalableDims().drop_front(numLeadingDims));
81}
82
83/// Return the shape-cast type for vector operands that match `vecTy`.
84static std::optional<VectorType>
85getShapeCastTypeForOperand(Value operand, VectorType vecTy,
86 unsigned numLeadingDims) {
87 auto operandTy = dyn_cast<VectorType>(operand.getType());
88 if (!operandTy || operandTy.getShape() != vecTy.getShape() ||
89 operandTy.getScalableDims() != vecTy.getScalableDims())
90 return std::nullopt;
91 return dropLeadingDims(operandTy, numLeadingDims);
92}
93
94namespace {
95template <typename LoadStoreOp>
96struct VectorLoadStoreLikeOpImpl final
97 : IndexedAccessOpInterface::ExternalModel<
98 VectorLoadStoreLikeOpImpl<LoadStoreOp>, LoadStoreOp> {
99 TypedValue<MemRefType> getAccessedMemref(Operation *op) const {
100 return cast<LoadStoreOp>(op).getBase();
101 }
102
103 Operation::operand_range getIndices(Operation *op) const {
104 return cast<LoadStoreOp>(op).getIndices();
105 }
106
107 SmallVector<int64_t> getAccessedShape(Operation *op) const {
109 "expected vector op with memref semantics");
110 return getAccessedVectorShape(cast<LoadStoreOp>(op).getVectorType(),
112 }
113
114 std::optional<SmallVector<Value>>
115 updateMemrefAndIndices(Operation *op, RewriterBase &rewriter, Value newMemref,
116 ValueRange newIndices) const {
118 "expected vector op with memref semantics");
119 assert(llvm::isa<MemRefType>(newMemref.getType()) &&
120 "expected replacement memref");
121
122 VectorType vecTy = cast<LoadStoreOp>(op).getVectorType();
123 if (static_cast<int64_t>(newIndices.size()) >= vecTy.getRank()) {
124 rewriter.modifyOpInPlace(op, [&]() {
125 auto concreteOp = cast<LoadStoreOp>(op);
126 concreteOp.getBaseMutable().assign(newMemref);
127 concreteOp.getIndicesMutable().assign(newIndices);
128 });
129 return std::nullopt;
130 }
131
132 unsigned numLeadingDimsToDrop = static_cast<unsigned>(
133 vecTy.getRank() - static_cast<int64_t>(newIndices.size()));
134 assert(numLeadingDimsToDrop <= getNumLeadingUnitDims(vecTy) &&
135 "expected only redundant leading unit dimensions to be dropped");
136
137 IRMapping dropDimsMap;
138 for (Value operand : op->getOperands()) {
139 std::optional<VectorType> castTy =
140 getShapeCastTypeForOperand(operand, vecTy, numLeadingDimsToDrop);
141 if (!castTy || dropDimsMap.lookupOrNull(operand))
142 continue;
143 Value castedOperand = vector::ShapeCastOp::create(
144 rewriter, operand.getLoc(), *castTy, operand);
145 dropDimsMap.map(operand, castedOperand);
146 }
147
148 if (op->getNumResults() == 1) {
149 // Result types cannot be changed in place on the original op because the
150 // caller replaces it using the returned value. Clone at the lower rank,
151 // then cast the result back to the original vector type.
152 VectorType droppedDimsTy = dropLeadingDims(vecTy, numLeadingDimsToDrop);
153 Operation *newOp = rewriter.clone(*op, dropDimsMap);
154 rewriter.modifyOpInPlace(newOp, [&]() {
155 auto concreteOp = cast<LoadStoreOp>(newOp);
156 concreteOp.getBaseMutable().assign(newMemref);
157 concreteOp.getIndicesMutable().assign(newIndices);
158 newOp->getResult(0).setType(droppedDimsTy);
159 });
160 Value castBack = vector::ShapeCastOp::create(rewriter, newOp->getLoc(),
161 vecTy, newOp->getResult(0));
162 return {{castBack}};
163 }
164
165 // Store-like ops have no results to replace, so update their vector
166 // operands and base/index pair in place.
167 rewriter.modifyOpInPlace(op, [&]() {
168 auto concreteOp = cast<LoadStoreOp>(op);
169 concreteOp.getBaseMutable().assign(newMemref);
170 concreteOp.getIndicesMutable().assign(newIndices);
171 for (OpOperand &operand : op->getOpOperands()) {
172 if (Value replacement = dropDimsMap.lookupOrNull(operand.get()))
173 operand.set(replacement);
174 }
175 });
176 return std::nullopt;
177 }
178
179 // TODO: The various load and store operations, at the very least vector.load
180 // and vector.store, should be taught a starts-in-bounds attribute that would
181 // let us optimize index generation.
182 bool hasInboundsIndices(Operation *op) const {
184 "expected vector op with memref semantics");
185 return false;
186 }
187};
188
189template <typename... Ops>
190static void attachLoadStoreLike(MLIRContext *ctx) {
191 (Ops::template attachInterface<VectorLoadStoreLikeOpImpl<Ops>>(*ctx), ...);
192}
193
194} // namespace
195
197 DialectRegistry &registry) {
198 registry.addExtension(+[](MLIRContext *ctx, vector::VectorDialect *dialect) {
199 attachLoadStoreLike<vector::LoadOp, vector::StoreOp, vector::MaskedLoadOp,
200 vector::MaskedStoreOp, vector::ExpandLoadOp,
201 vector::CompressStoreOp>(ctx);
202 });
203}
static Value getBase(Value v)
Looks through known "view-like" ops to find the base memref.
... if copies could not be generated due to yet unimplemented cases. `copyInPlacementStart` and `copyOutPlacementStart` in `copyPlacementBlock` specify the insertion points where the incoming copies and outgoing copies should be inserted ... the output argument `nBegin` is set to its replacement (set to `begin` if no invalidation happens). Since outgoing copies could have been inserted at `end`, ...
static VectorType getVectorType(Type scalarTy, const VectorizationStrategy *strategy)
Returns the vector type resulting from applying the provided vectorization strategy on the scalar typ...
static bool hasMemrefSemantics(Operation *op)
Return true if this op has the memref semantics expected by this model.
static std::optional< VectorType > getShapeCastTypeForOperand(Value operand, VectorType vecTy, unsigned numLeadingDims)
Return the shape-cast type for vector operands that match vecTy.
static constexpr bool supportsRankZeroVectorAccess()
Return true if this op supports rank-0 vector operands/results.
static SmallVector< int64_t > getAccessedVectorShape(VectorType vecTy, bool supportsRankZero)
Return the vector shape whose access strides must be preserved, omitting redundant leading static uni...
static unsigned getNumLeadingUnitDims(VectorType vecTy)
Return the number of leading static unit dimensions in vecTy.
static VectorType dropLeadingDims(VectorType vecTy, unsigned numLeadingDims)
Return vecTy with numLeadingDims dimensions dropped from the front.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
auto lookupOrNull(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:58
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:567
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:432
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:240
MutableArrayRef< OpOperand > getOpOperands()
Definition Operation.h:408
OperandRange operand_range
Definition Operation.h:396
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:403
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:429
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
void setType(Type newType)
Mutate the type of this Value to be of the specified type.
Definition Value.h:116
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition Utils.cpp:18
void registerIndexedAccessOpInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:494