MLIR  16.0.0git
VectorPattern.cpp
Go to the documentation of this file.
1 //===- VectorPattern.cpp - Vector conversion pattern to the LLVM dialect --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 
12 using namespace mlir;
13 
14 // For >1-D vector types, extracts the necessary information to iterate over all
15 // 1-D subvectors in the underlying LLVM representation of the n-D vector.
16 // Iterates on the LLVM array type until we hit a non-array type (which is
17 // expected to be an LLVM vector type; if it is not, llvm1DVectorTy is not set).
20  LLVMTypeConverter &converter) {
21  assert(vectorType.getRank() > 1 && "expected >1D vector type");
22  NDVectorTypeInfo info;
23  info.llvmNDVectorTy = converter.convertType(vectorType);
25  info.llvmNDVectorTy = nullptr;
26  return info;
27  }
28  info.arraySizes.reserve(vectorType.getRank() - 1);
29  auto llvmTy = info.llvmNDVectorTy;
30  while (llvmTy.isa<LLVM::LLVMArrayType>()) {
31  info.arraySizes.push_back(
32  llvmTy.cast<LLVM::LLVMArrayType>().getNumElements());
33  llvmTy = llvmTy.cast<LLVM::LLVMArrayType>().getElementType();
34  }
35  if (!LLVM::isCompatibleVectorType(llvmTy))
36  return info;
37  info.llvm1DVectorTy = llvmTy;
38  return info;
39 }
40 
41 // Express `linearIndex` in terms of coordinates of `basis`.
42 // Returns the empty vector when linearIndex is out of the range [0, P) where
43 // P is the product of all the basis coordinates.
44 //
45 // Prerequisites:
46 // Basis is an array of positive integers (signed type inherited from
47 // vector shape type); a zero entry would cause a division by zero.
49  unsigned linearIndex) {
51  res.reserve(basis.size());
52  for (unsigned basisElement : llvm::reverse(basis)) {
53  res.push_back(linearIndex % basisElement);
54  linearIndex = linearIndex / basisElement;
55  }
56  if (linearIndex > 0)
57  return {};
58  std::reverse(res.begin(), res.end());
59  return res;
60 }
61 
62 // Iterate over every linear index of the `info.arraySizes` shape, convert each
63 // index to coordinates, and invoke `fun` with each resulting position.
65  OpBuilder &builder,
66  function_ref<void(ArrayRef<int64_t>)> fun) {
67  unsigned ub = 1;
68  for (auto s : info.arraySizes)
69  ub *= s;
70  for (unsigned linearIndex = 0; linearIndex < ub; ++linearIndex) {
71  auto coords = getCoordinates(info.arraySizes, linearIndex);
72  // Linear index is out of bounds, we are done.
73  if (coords.empty())
74  break;
75  assert(coords.size() == info.arraySizes.size());
76  fun(coords);
77  }
78 }
79 
81  Operation *op, ValueRange operands, LLVMTypeConverter &typeConverter,
82  std::function<Value(Type, ValueRange)> createOperand,
83  ConversionPatternRewriter &rewriter) {
84  auto resultNDVectorType = op->getResult(0).getType().cast<VectorType>();
85  auto resultTypeInfo =
86  extractNDVectorTypeInfo(resultNDVectorType, typeConverter);
87  auto result1DVectorTy = resultTypeInfo.llvm1DVectorTy;
88  auto resultNDVectoryTy = resultTypeInfo.llvmNDVectorTy;
89  auto loc = op->getLoc();
90  Value desc = rewriter.create<LLVM::UndefOp>(loc, resultNDVectoryTy);
91  nDVectorIterate(resultTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
92  // For this unrolled `position` corresponding to the `linearIndex`^th
93  // element, extract operand vectors
94  SmallVector<Value, 4> extractedOperands;
95  for (const auto &operand : llvm::enumerate(operands)) {
96  extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>(
97  loc, operand.value(), position));
98  }
99  Value newVal = createOperand(result1DVectorTy, extractedOperands);
100  desc = rewriter.create<LLVM::InsertValueOp>(loc, desc, newVal, position);
101  });
102  rewriter.replaceOp(op, desc);
103  return success();
104 }
105 
107  Operation *op, StringRef targetOp, ValueRange operands,
108  ArrayRef<NamedAttribute> targetAttrs, LLVMTypeConverter &typeConverter,
109  ConversionPatternRewriter &rewriter) {
110  assert(!operands.empty());
111 
112  // Cannot convert ops if their operands are not of LLVM type.
113  if (!llvm::all_of(operands.getTypes(), isCompatibleType))
114  return failure();
115 
116  auto llvmNDVectorTy = operands[0].getType();
117  if (!llvmNDVectorTy.isa<LLVM::LLVMArrayType>())
118  return oneToOneRewrite(op, targetOp, operands, targetAttrs, typeConverter,
119  rewriter);
120 
121  auto callback = [op, targetOp, targetAttrs, &rewriter](Type llvm1DVectorTy,
122  ValueRange operands) {
123  return rewriter
124  .create(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
125  llvm1DVectorTy, targetAttrs)
126  ->getResult(0);
127  };
128 
129  return handleMultidimensionalVectors(op, operands, typeConverter, callback,
130  rewriter);
131 }
static Value handleMultidimensionalVectors(ImplicitLocOpBuilder &builder, ValueRange operands, int64_t vectorWidth, llvm::function_ref< Value(ValueRange)> compute)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:698
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:243
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing the results of an operation.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:30
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results)
Convert the given type.
This class helps build Operations.
Definition: Builders.h:198
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:422
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:31
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:324
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:154
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
U cast() const
Definition: Types.h:280
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:349
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:85
Type getType() const
Return the type of this value.
Definition: Value.h:114
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands.
Definition: Pattern.cpp:305
NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType, LLVMTypeConverter &converter)
void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder, function_ref< void(ArrayRef< int64_t >)> fun)
LogicalResult vectorOneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
SmallVector< int64_t, 4 > getCoordinates(ArrayRef< int64_t > basis, unsigned linearIndex)
bool isCompatibleVectorType(Type type)
Returns true if the given type is a vector type compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:857
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:839
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:230
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
SmallVector< int64_t, 4 > arraySizes
Definition: VectorPattern.h:27
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26