22 #include "llvm/ADT/ArrayRef.h"
29 static FailureOr<Attribute>
33 if (
auto dstElementsAttr = dyn_cast<DenseElementsAttr>(value)) {
34 if (resType.isScalable() && !isa<SplatElementsAttr>(value))
37 "Cannot linearize a constant scalable vector that's not a splat");
39 return dstElementsAttr.reshape(resType);
42 if (
auto poisonAttr = dyn_cast<ub::PoisonAttr>(value))
50 struct LinearizeConstantLike final
67 assert(resType &&
"expected 1-D vector type");
74 FailureOr<Attribute> newValue =
79 FailureOr<Operation *> convertResult =
81 if (failed(convertResult))
85 newOp->
setAttr(attrName, *newValue);
91 struct LinearizeVectorizable final
102 FailureOr<Operation *> newOp =
107 rewriter.
replaceOp(op, (*newOp)->getResults());
112 template <
typename TOp>
113 static bool stridesAllOne(TOp op) {
115 std::is_same_v<TOp, vector::ExtractStridedSliceOp> ||
116 std::is_same_v<TOp, vector::InsertStridedSliceOp>,
117 "expected vector.extract_strided_slice or vector.insert_strided_slice");
118 ArrayAttr strides = op.getStrides();
123 static FailureOr<SmallVector<int64_t>> intsFromArrayAttr(ArrayAttr attrs) {
127 ints.reserve(attrs.size());
128 for (
auto attr : attrs) {
129 if (
auto intAttr = dyn_cast<IntegerAttr>(attr)) {
130 ints.push_back(intAttr.getInt());
158 assert((large.size() >= small.size()) &&
159 "rank of 'large' cannot be lower than rank of 'small'");
160 assert((large.size() >= offsets.size()) &&
161 "rank of 'large' cannot be lower than the number of offsets");
162 unsigned delta = large.size() - small.size();
163 unsigned nOffsets = offsets.size();
164 auto getSmall = [&](int64_t i) -> int64_t {
165 return i >= delta ? small[i - delta] : 1;
167 auto getOffset = [&](int64_t i) -> int64_t {
168 return i < nOffsets ? offsets[i] : 0;
176 for (
int i = large.size() - 1; i >= 0; --i) {
177 int64_t currentSize = indices.size();
178 int64_t smallSize = getSmall(i);
179 int64_t nextSize = currentSize * smallSize;
181 int64_t *base = nextIndices.begin();
182 int64_t offset = getOffset(i) * stride;
183 for (
int j = 0;
j < smallSize; ++
j) {
184 for (
int k = 0; k < currentSize; ++k) {
185 base[k] = indices[k] + offset;
191 indices = std::move(nextIndices);
215 struct LinearizeVectorExtractStridedSlice final
218 LinearizeVectorExtractStridedSlice(
const TypeConverter &typeConverter,
224 matchAndRewrite(vector::ExtractStridedSliceOp extractStridedSliceOp,
228 VectorType flatOutputType = getTypeConverter()->convertType<VectorType>(
229 extractStridedSliceOp.getType());
230 assert(flatOutputType &&
"vector type expected");
234 if (!stridesAllOne(extractStridedSliceOp)) {
236 extractStridedSliceOp,
237 "extract_strided_slice with strides != 1 not supported");
240 FailureOr<SmallVector<int64_t>> offsets =
241 intsFromArrayAttr(extractStridedSliceOp.getOffsets());
242 if (failed(offsets)) {
244 "failed to get integer offsets");
248 extractStridedSliceOp.getSourceVectorType().getShape();
253 outputShape, inputShape, offsets.value());
255 Value srcVector = adaptor.getVector();
257 extractStridedSliceOp, flatOutputType, srcVector, srcVector, indices);
286 struct LinearizeVectorInsertStridedSlice final
289 LinearizeVectorInsertStridedSlice(
const TypeConverter &typeConverter,
295 matchAndRewrite(vector::InsertStridedSliceOp insertStridedSliceOp,
301 if (!stridesAllOne(insertStridedSliceOp)) {
303 insertStridedSliceOp,
304 "insert_strided_slice with strides != 1 not supported");
307 VectorType inputType = insertStridedSliceOp.getValueToStore().getType();
310 VectorType outputType = insertStridedSliceOp.getType();
312 int64_t nOutputElements = outputType.getNumElements();
314 FailureOr<SmallVector<int64_t>> offsets =
315 intsFromArrayAttr(insertStridedSliceOp.getOffsets());
316 if (failed(offsets)) {
318 "failed to get integer offsets");
321 inputShape, outputShape, offsets.value());
324 std::iota(indices.begin(), indices.end(), 0);
326 indices[sliceIndex] = index + nOutputElements;
329 Value flatToStore = adaptor.getValueToStore();
330 Value flatDest = adaptor.getDest();
333 flatToStore, indices);
349 struct LinearizeVectorShuffle final
357 matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
360 getTypeConverter()->convertType<VectorType>(shuffleOp.getType());
361 assert(dstType &&
"vector type destination expected.");
363 Value vec1 = adaptor.getV1();
364 Value vec2 = adaptor.getV2();
365 int shuffleSliceLen = 1;
366 int rank = shuffleOp.getV1().
getType().getRank();
374 for (
unsigned i = 1; i < shape.size(); ++i) {
375 shuffleSliceLen *= shape[i];
384 int64_t totalSizeOfShuffledElmnts = mask.size() * shuffleSliceLen;
387 std::iota(indices.begin() + shuffleSliceLen * i,
388 indices.begin() + shuffleSliceLen * (i + 1),
389 shuffleSliceLen * value);
407 struct LinearizeVectorExtract final
414 matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
417 if (!isa<VectorType>(extractOp.getType()))
419 "scalar extract not supported");
420 Type dstTy = getTypeConverter()->convertType(extractOp.getType());
421 assert(dstTy &&
"expected 1-D vector type");
424 if (extractOp.hasDynamicPosition())
426 "dynamic position is not supported.");
429 int64_t size = extractOp.getVector().getType().getNumElements();
432 int64_t linearizedOffset = 0;
436 linearizedOffset += offsets[i] * size;
440 std::iota(indices.begin(), indices.end(), linearizedOffset);
442 extractOp, dstTy, adaptor.getVector(), adaptor.getVector(), indices);
458 struct LinearizeVectorInsert final
465 matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
467 VectorType dstTy = getTypeConverter()->convertType<VectorType>(
468 insertOp.getDestVectorType());
469 assert(dstTy &&
"vector type destination expected.");
472 if (insertOp.hasDynamicPosition())
474 "dynamic position is not supported.");
475 auto srcTy = insertOp.getValueToStoreType();
476 auto srcAsVec = dyn_cast<VectorType>(srcTy);
477 uint64_t srcSize = 0;
479 srcSize = srcAsVec.getNumElements();
482 "scalars are not supported.");
485 auto dstShape = insertOp.getDestVectorType().getShape();
486 const auto dstSize = insertOp.getDestVectorType().getNumElements();
487 auto dstSizeForOffsets = dstSize;
490 int64_t linearizedOffset = 0;
491 auto offsetsNd = insertOp.getStaticPosition();
493 dstSizeForOffsets /= dstShape[dim];
494 linearizedOffset += offset * dstSizeForOffsets;
498 auto *origValsUntil = indices.begin();
499 std::advance(origValsUntil, linearizedOffset);
500 std::iota(indices.begin(), origValsUntil,
502 auto *newValsUntil = origValsUntil;
503 std::advance(newValsUntil, srcSize);
504 std::iota(origValsUntil, newValsUntil,
506 std::iota(newValsUntil, indices.end(),
507 linearizedOffset + srcSize);
511 insertOp, dstTy, adaptor.getDest(), adaptor.getValueToStore(), indices);
525 struct LinearizeVectorBitCast final
532 matchAndRewrite(vector::BitCastOp castOp, OpAdaptor adaptor,
534 auto resType = getTypeConverter()->convertType(castOp.getType());
535 assert(resType &&
"expected 1-D vector type");
537 adaptor.getSource());
538 return mlir::success();
548 struct LinearizeVectorSplat final
557 matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
559 auto dstTy = getTypeConverter()->convertType(splatOp.getType());
579 struct LinearizeVectorCreateMask final
583 LinearizeVectorCreateMask(
const TypeConverter &typeConverter,
588 matchAndRewrite(vector::CreateMaskOp createMaskOp, OpAdaptor adaptor,
590 Location loc = createMaskOp.getLoc();
591 VectorType srcTy = createMaskOp.getType();
592 auto srcShape = srcTy.getShape();
593 if (srcShape.size() != 2)
595 "only 2D mask is supported.");
597 if (srcShape[0] != 1)
599 createMaskOp,
"only unit outer dimension is supported.");
601 auto dstTy = getTypeConverter()->convertType(srcTy);
609 auto firstOperand = adaptor.getOperands().front();
611 auto isNonZero = rewriter.
createOrFold<mlir::arith::CmpIOp>(
612 loc, mlir::arith::CmpIPredicate::sgt, firstOperand, zero);
613 auto isNonZeroIndex = rewriter.
createOrFold<mlir::arith::IndexCastOp>(
615 auto secondOperand = adaptor.getOperands().back();
616 auto maskSize = rewriter.
createOrFold<mlir::arith::AndIOp>(
617 loc, rewriter.
getIndexType(), isNonZeroIndex, secondOperand);
620 rewriter.
create<mlir::vector::CreateMaskOp>(loc, dstTy, maskSize);
621 rewriter.
replaceOp(createMaskOp, newMask);
634 StringLiteral vectorDialect = vector::VectorDialect::getDialectNamespace();
636 bool supported = (opDialect == vectorDialect) ||
646 .Case<vector::ShapeCastOp>([&](
auto) {
return false; })
656 .Case<vector::ExtractStridedSliceOp>(
657 [&](vector::ExtractStridedSliceOp extractOp) {
658 return !extractOp.getType().isScalable();
660 .Case<vector::InsertStridedSliceOp>(
661 [&](vector::InsertStridedSliceOp insertOp) {
662 return !insertOp.getType().isScalable();
664 .Case<vector::InsertOp>([&](vector::InsertOp insertOp) {
665 return !insertOp.getType().isScalable();
667 .Case<vector::ExtractOp>([&](vector::ExtractOp extractOp) {
668 return !extractOp.getSourceVectorType().isScalable();
670 .Default([&](
auto) {
return true; });
673 void mlir::vector::populateForVectorLinearize(
TypeConverter &typeConverter,
676 auto convertType = [](
Type type) -> std::optional<Type> {
677 VectorType vectorType = dyn_cast<VectorType>(type);
681 VectorType linearizedType =
683 vectorType.getElementType(), vectorType.isScalable());
684 return linearizedType;
690 if (inputs.size() != 1)
693 Value value = inputs.front();
694 if (!isa<VectorType>(type) || !isa<VectorType>(value.
getType()))
697 return builder.
create<vector::ShapeCastOp>(loc, type, value);
703 [=](
Operation *op) -> std::optional<bool> {
708 return typeConverter.
isLegal(op);
712 void mlir::vector::populateVectorLinearizeBasePatterns(
716 .add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
717 LinearizeVectorSplat, LinearizeVectorCreateMask>(
718 typeConverter,
patterns.getContext());
721 void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
724 patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
725 LinearizeVectorInsert, LinearizeVectorExtractStridedSlice,
726 LinearizeVectorInsertStridedSlice>(typeConverter,
static FailureOr< Attribute > linearizeConstAttr(Location loc, ConversionPatternRewriter &rewriter, VectorType resType, Attribute value)
static bool isLinearizable(Operation *op)
This method defines the set of operations that are linearizable, and hence that are considered illega...
Attributes are known-constant values of operations.
StringAttr getStringAttr(const Twine &bytes)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class describes a specific conversion target.
void markUnknownOpDynamicallyLegal(const DynamicLegalityCallbackFn &fn)
Register unknown operations as dynamically legal.
StringRef getNamespace() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
OpTraitConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting...
OpTraitConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This class provides the API for a sub-set of ops that are known to be constant-like.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
unsigned getNumResults()
Return the number of results held by this operation.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
void addConversion(FnT &&callback)
Register a conversion function.
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
void addSourceMaterialization(FnT &&callback)
All of the following materializations require function objects that are convertible to the following ...
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a value to a target type ...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Specialization of arith.constant op that returns an integer of index type.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
bool isLinearizableVector(VectorType type)
Returns true if the input Vector type can be linearized.
Include the generated interface declarations.
FailureOr< Operation * > convertOpResultTypes(Operation *op, ValueRange operands, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Generic utility to convert op result types according to type converter without knowing exact op type.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
This trait tags elementwise operations that can be systematically vectorized.
Eliminates the variable at the specified position using Fourier-Motzkin variable elimination.