MLIR 23.0.0git
LinalgOps.cpp
Go to the documentation of this file.
1//===- LinalgOps.cpp - Implementation of the linalg operations ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the Linalg operations.
10//
11//===----------------------------------------------------------------------===//
12
14
28#include "mlir/IR/AffineMap.h"
29#include "mlir/IR/Attributes.h"
30#include "mlir/IR/Builders.h"
33#include "mlir/IR/Matchers.h"
40
41#include "llvm/ADT/DenseMap.h"
42#include "llvm/ADT/STLExtras.h"
43#include "llvm/ADT/SetOperations.h"
44#include "llvm/ADT/SmallVector.h"
45#include "llvm/ADT/SmallVectorExtras.h"
46#include "llvm/ADT/StringSet.h"
47#include "llvm/ADT/TypeSwitch.h"
48#include "llvm/Support/FormatVariadic.h"
49#include "llvm/Support/InterleavedRange.h"
50#include "llvm/Support/LogicalResult.h"
51#include "llvm/Support/MathExtras.h"
52#include "llvm/Support/raw_ostream.h"
53#include <cassert>
54#include <optional>
55
56using namespace mlir;
57using namespace mlir::linalg;
58
59/// Return a `memref.dim` or `tensor.dim` for the shape of `v` at `dim`.
61 int64_t dim) {
62 auto type = cast<ShapedType>(v.getType());
63 if (!type.isDynamicDim(dim))
64 return builder.getIndexAttr(type.getDimSize(dim));
65
66 return getAsOpFoldResult(
68 .Case([&](RankedTensorType t) -> Value {
69 return tensor::DimOp::create(builder, loc, v, dim);
70 })
71 .Case([&](MemRefType t) -> Value {
72 return memref::DimOp::create(builder, loc, v, dim);
73 }));
74}
75
76/// Returns a memref.subview or a tensor.extract_slice based on the type of the
77/// `source`.
81 ArrayRef<OpFoldResult> strides) {
83 .Case([&](RankedTensorType t) -> Operation * {
84 return tensor::ExtractSliceOp::create(b, loc, source, offsets, sizes,
85 strides);
86 })
87 .Case([&](MemRefType type) -> Operation * {
88 return memref::SubViewOp::create(b, loc, source, offsets, sizes,
89 strides);
90 })
91 .Default([&](Type t) -> Operation * { return nullptr; });
92}
93
94static std::optional<TypedAttr>
96 DenseElementsAttr splatAttr;
98 if (!splatAttr || !splatAttr.isSplat())
99 return std::nullopt;
100
101 return splatAttr.getSplatValue<TypedAttr>();
102}
103
104//===----------------------------------------------------------------------===//
105// Helper functions
106//===----------------------------------------------------------------------===//
107
109 int64_t dim) {
110 if (llvm::isa<UnrankedMemRefType, MemRefType>(source.getType()))
111 return b.createOrFold<memref::DimOp>(loc, source, dim);
112 if (llvm::isa<UnrankedTensorType, RankedTensorType>(source.getType()))
113 return b.createOrFold<tensor::DimOp>(loc, source, dim);
114 llvm_unreachable("Expected MemRefType or TensorType");
115}
116
118 int64_t dim) {
119 auto shapedType = llvm::cast<ShapedType>(source.getType());
120 if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
121 return createOrFoldDimOp(b, loc, source, dim);
122 return b.getIndexAttr(shapedType.getDimSize(dim));
123}
124
125//===----------------------------------------------------------------------===//
126// Support for named Linalg ops defined in ods-gen.
127//===----------------------------------------------------------------------===//
128
132
133/// Fills the region of a structured operation using the provided
134/// `regionBuilder`. The method is used by both named structured ops created by
135/// ods-gen and by manually defined C++ ops. It is called by both builders and
136/// parsers and creates a block with arguments corresponding to the elemental
137/// types of `inputTypes` and `outputTypes`.
// NOTE(review): this listing skips some original lines (its numbering jumps
// 139 -> 142 and 143 -> 145). `attrs`, `emitError` and `argLocs`, which are
// used below, are presumably declared on those missing lines - confirm against
// the complete file.
138static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
139 TypeRange inputTypes, TypeRange outputTypes,
142 RegionBuilderFn regionBuilder) {
143 SmallVector<Type, 8> argTypes;
// One block argument is created per input type, followed by one per output
// type, in that order.
145 for (auto containers : {inputTypes, outputTypes}) {
146 for (auto t : containers) {
// Memref and ranked tensor types are mapped through getElementTypeOrSelf;
// every other type is used unchanged as the block argument type.
147 argTypes.push_back(
148 isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);
149
150 // TODO: Pass in a proper location here.
151 argLocs.push_back(opBuilder.getUnknownLoc());
152 }
153 }
154
155 // RAII.
156 OpBuilder::InsertionGuard guard(opBuilder);
// Create the single body block of `region` with the argument types and
// locations collected above.
157 Block *body =
158 opBuilder.createBlock(&region, /*insertPt=*/{}, argTypes, argLocs);
159
// The payload is emitted by `regionBuilder` at the start of the new block,
// through a builder whose implicit location is the unknown location.
160 opBuilder.setInsertionPointToStart(body);
161 ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
162 regionBuilder(b, *body, attrs, emitError);
163
164 // indexing_maps is an auto-generated method.
165
166 // iterator_types is an auto-generated method.
167}
168
169/// Creates a structured operation given `inputs`, `outputs`, and `attributes`.
170/// The result types are derived automatically if `resultTensorTypes` is none.
171/// The body of the operation is filled using `regionBuilder`. All ods-gen
172/// created structured operations use the method to implement their builders.
174 std::optional<TypeRange> resultTensorTypes,
175 ValueRange inputs, ValueRange outputs,
176 ArrayRef<NamedAttribute> attributes,
177 RegionBuilderFn regionBuilder) {
// NOTE(review): the first line of this signature is missing from this listing
// (numbering jumps 172 -> 174); `b` and `state` used below are presumably
// declared there - confirm against the complete file.
178 // Derive the result types if needed.
// Start from the caller-provided result types when present; otherwise the
// result types are the types of those outputs that are ranked tensors.
179 SmallVector<Type> derivedResultTypes =
180 resultTensorTypes.value_or(TypeRange());
181 if (!resultTensorTypes)
182 copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes),
183 llvm::IsaPred<RankedTensorType>);
184
// Operands are added as all inputs followed by all outputs.
185 state.addOperands(inputs);
186 state.addOperands(outputs);
187 state.addTypes(derivedResultTypes);
188
189 state.addAttributes(attributes);
// Record the input/output split of the operand list as a two-element i32
// array: {number of inputs, number of outputs}.
190 state.addAttribute(
191 "operandSegmentSizes",
192 b.getDenseI32ArrayAttr({static_cast<int32_t>(inputs.size()),
193 static_cast<int32_t>(outputs.size())}));
194
195 // Create and fill the region of the structured operation.
// The region builder sees every attribute accumulated on `state` so far; an
// empty diagnostic callback is passed for `emitError`.
196 Region &region = *state.addRegion();
197 fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs),
198 state.attributes.getAttrs(), /*emitError=*/{},
199 regionBuilder);
200}
201
203 std::optional<TypeRange> resultTensorTypes,
204 ValueRange inputs, ValueRange outputs,
205 ArrayRef<NamedAttribute> attributes,
206 RegionBuilderFn regionBuilder,
207 ArrayRef<AffineMap> defaultIndexingMaps) {
208 // If indexing maps are not provided, apply the default ones.
209 if (none_of(attributes, [](NamedAttribute attr) {
210 return attr.getName() == "indexing_maps";
211 })) {
212 SmallVector<Attribute, 3> indexingMapsAttrVal;
213 indexingMapsAttrVal = llvm::map_to_vector(
214 defaultIndexingMaps,
215 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
216 state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
217 }
218 return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
219 attributes, regionBuilder);
220}
221
223 std::optional<TypeRange> resultTensorTypes,
224 ValueRange inputs, ValueRange outputs,
225 ArrayRef<NamedAttribute> attributes,
226 RegionBuilderFn regionBuilder,
227 ArrayRef<AffineMap> defaultIndexingMaps) {
228 // If indexing maps are not provided, apply the default ones.
229 if (none_of(attributes, [](NamedAttribute attr) {
230 return attr.getName() == "indexing_maps";
231 })) {
232 SmallVector<Attribute, 4> indexingMapsAttrVal;
233 indexingMapsAttrVal = llvm::map_to_vector(
234 defaultIndexingMaps,
235 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
236 state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
237 }
238 return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
239 attributes, regionBuilder);
240}
241
243 std::optional<TypeRange> resultTensorTypes,
244 ValueRange inputs, ValueRange outputs,
245 ArrayRef<NamedAttribute> attributes,
246 RegionBuilderFn regionBuilder,
247 ArrayRef<AffineMap> indexingMaps) {
248 // Initialize indexingMaps attribute, for BatchReduceMatmulOp.
249 SmallVector<Attribute, 4> indexingMapsAttrVal;
250 indexingMapsAttrVal =
251 llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
252 return AffineMapAttr::get(map);
253 });
254 state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
255 return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
256 attributes, regionBuilder);
257}
258
259/// Common parsing used for both named structured ops created by ods-gen and by
260/// manually defined C++ ops. Does not handle regions.
// Parses, in order: an optional `<properties-attr>`, an optional attribute
// dictionary, an optional `ins(operands : types)` clause and an optional
// `outs(operands : types)` clause, then resolves the operands into `result`.
// NOTE(review): several original lines are missing from this listing
// (numbering jumps 261 -> 263, 266 -> 268, 310 -> 312 and 316 -> 318);
// `parser`, `result`, `inputsOperands` and `outputsOperands` are presumably
// declared on them - confirm against the complete file.
261static ParseResult
263 SmallVectorImpl<Type> &inputTypes,
264 SmallVectorImpl<Type> &outputTypes,
265 bool addOperandSegmentSizes = true) {
266 SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;
268 outputsOperands;
269
// Optional properties, written as an attribute between `<` and `>`.
270 if (succeeded(parser.parseOptionalLess())) {
271 if (parser.parseAttribute(result.propertiesAttr) || parser.parseGreater())
272 return failure();
273 }
// Remember where the attribute dictionary starts so that inherent-attribute
// verification errors (at the end of this function) point at it.
274 attrsLoc = parser.getCurrentLocation();
275 if (parser.parseOptionalAttrDict(result.attributes))
276 return failure();
277
// `ins` clause: the operand location is captured after the opening paren.
278 if (succeeded(parser.parseOptionalKeyword("ins"))) {
279 if (parser.parseLParen())
280 return failure();
281
282 inputsOperandsLoc = parser.getCurrentLocation();
283 if (parser.parseOperandList(inputsOperands) ||
284 parser.parseColonTypeList(inputTypes) || parser.parseRParen())
285 return failure();
286 }
287
// `outs` clause: here the operand location is captured before the opening
// paren, unlike the `ins` clause above.
288 if (succeeded(parser.parseOptionalKeyword("outs"))) {
289 outputsOperandsLoc = parser.getCurrentLocation();
290 if (parser.parseLParen() || parser.parseOperandList(outputsOperands) ||
291 parser.parseColonTypeList(outputTypes) || parser.parseRParen())
292 return failure();
293 }
294
// Inputs are resolved first, then outputs, so `result.operands` holds all
// inputs followed by all outputs.
295 if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
296 result.operands) ||
297 parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
298 result.operands))
299 return failure();
300
301 if (addOperandSegmentSizes) {
302 // This is a bit complex because we're trying to be backward compatible with
303 // operation syntax that mix the inherent attributes and the discardable
304 // ones in the same dictionary. If the properties are used, we append the
305 // operandSegmentSizes there directly. Otherwise we append it to the
306 // discardable attributes dictionary where it is handled by the generic
307 // Operation::create(...) method.
308 if (result.propertiesAttr) {
309 NamedAttrList attrs = llvm::cast<DictionaryAttr>(result.propertiesAttr);
310 attrs.append("operandSegmentSizes",
312 {static_cast<int32_t>(inputsOperands.size()),
313 static_cast<int32_t>(outputsOperands.size())}));
314 result.propertiesAttr = attrs.getDictionary(parser.getContext());
315 } else {
316 result.addAttribute("operandSegmentSizes",
318 {static_cast<int32_t>(inputsOperands.size()),
319 static_cast<int32_t>(outputsOperands.size())}));
320 }
321 }
// Without a properties attribute, the inherent attributes came from the
// attribute dictionary; verify them through the registered op info when the
// op name is registered. Errors are reported at `attrsLoc`.
322 if (!result.propertiesAttr) {
323 std::optional<RegisteredOperationName> info =
324 result.name.getRegisteredInfo();
325 if (info) {
326 if (failed(info->verifyInherentAttrs(result.attributes, [&]() {
327 return parser.emitError(attrsLoc)
328 << "'" << result.name.getStringRef() << "' op ";
329 })))
330 return failure();
331 }
332 }
333 return success();
334}
335
337 ValueRange outputs) {
338 if (!inputs.empty())
339 p << " ins(" << inputs << " : " << inputs.getTypes() << ")";
340 if (!outputs.empty())
341 p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
342}
343
344//===----------------------------------------------------------------------===//
345// Specific parsing and printing for named structured ops created by ods-gen.
346//===----------------------------------------------------------------------===//
347
349 OpAsmParser &parser, Region &region, unsigned numRegionArgs,
350 TypeRange inputTypes, TypeRange outputTypes, ArrayRef<NamedAttribute> attrs,
351 RegionBuilderFn regionBuilder, SMLoc loc) {
352 if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
353 return parser.emitError(
354 parser.getCurrentLocation(),
355 llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
356 "region expects {0} args, got {1}",
357 numRegionArgs, inputTypes.size() + outputTypes.size()));
358 }
359
360 OpBuilder opBuilder(parser.getContext());
361 ParseResult result = success();
363 opBuilder, region, inputTypes, outputTypes, attrs,
364 [&]() {
365 result = failure();
366 return parser.emitError(loc);
367 },
368 regionBuilder);
369 return result;
370}
371
372static ParseResult
374 SmallVectorImpl<Type> &resultTypes) {
375 if (parser.parseOptionalArrowTypeList(resultTypes))
376 return failure();
377 return success();
378}
379
// Parses a named structured op: the common structured-op parts, a trailing
// optional attribute dictionary, an optional `-> result-types` list, and then
// builds (rather than parses) the region through `regionBuilder`.
// Returns failure as soon as any step fails.
// NOTE(review): one parameter line is missing from this listing (numbering
// jumps 380 -> 382); `result` is presumably declared there - confirm against
// the complete file.
380static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
382 unsigned numRegionArgs,
383 RegionBuilderFn regionBuilder) {
384 // TODO: Enable when ods-gen supports captures.
385 SmallVector<Type, 1> inputTypes, outputTypes;
// Location of the start of the op's custom syntax; forwarded to the region
// step below.
386 SMLoc loc = parser.getCurrentLocation();
387 if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
388 return failure();
389
390 // Parse optional attributes.
391 if (parser.parseOptionalAttrDict(result.attributes))
392 return failure();
393
394 // TODO: consider merging results parsing into region parsing.
395 // Need to wait for declarative assembly resolution to decide.
396 SmallVector<Type, 1> outputTensorsTypes;
397 if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
398 return failure();
399 result.addTypes(outputTensorsTypes);
400
// The region is owned locally until it has been filled successfully, and only
// then handed over to `result`.
401 std::unique_ptr<Region> region = std::make_unique<Region>();
402 if (parseNamedStructuredOpRegion(parser, *region, numRegionArgs, inputTypes,
403 outputTypes, result.attributes.getAttrs(),
404 regionBuilder, loc))
405 return failure();
406 result.addRegion(std::move(region));
407
408 return success();
409}
410
412 TypeRange resultTypes) {
413 if (resultTypes.empty())
414 return;
415 p.printOptionalArrowTypeList(resultTypes);
416}
417
419 ValueRange inputs, ValueRange outputs,
420 ArrayRef<StringRef> elidedAttrs = {}) {
421 p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
422
423 // Printing is shared with generic ops, except for the region and
424 // attributes.
425 printCommonStructuredOpParts(p, inputs, outputs);
426
427 // Results printing.
429
430 // Region is elided.
431}
432
433//===----------------------------------------------------------------------===//
434// Region builder helper.
435// TODO: Move this to a utility library.
436// The public methods on this class are referenced directly from generated code.
437// Helper build the unary, binary, and type conversion functions defined by the
438// DSL. See LinalgNamedStructuredOps.yamlgen.cpp.inc for the code that uses this
439// class.
440//
441// Implementations of the math functions must be polymorphic over numeric types,
442// internally performing necessary casts. If the function application makes no
443// sense, then the only recourse is to assert and return nullptr. This can be
444// extended later if it becomes possible to fail construction of the region. The
445// invariant should be enforced at a higher level.
446//
447// TODO: These helpers are currently type polymorphic over the class of integer
448// and floating point types, but they will not internally cast within bit
449// widths of a class (mixed precision such as i8->i32) or across classes
450// (i.e. mixed float and integer). Many such combinations are ambiguous or need
451// to be handled with care and work is being considered to extend the op
452// language to make such cases explicit. In the meantime, violating this will
453// fail verification, which is deemed acceptable.
454//===----------------------------------------------------------------------===//
455
456namespace {
457
458class RegionBuilderHelper {
459public:
460 RegionBuilderHelper(OpBuilder &builder, Block &block)
461 : builder(builder), block(block) {}
462
463 // Build the unary functions defined by OpDSL.
464 Value buildUnaryFn(UnaryFn unaryFn, Value arg,
465 function_ref<InFlightDiagnostic()> emitError = {}) {
466 if (!isFloatingPoint(arg)) {
467 if (emitError) {
468 emitError() << "unsupported non numeric type";
469 return nullptr;
470 }
471 llvm_unreachable("unsupported non numeric type");
472 }
473 OpBuilder::InsertionGuard g(builder);
474 builder.setInsertionPointToEnd(&block);
475 switch (unaryFn) {
476 case UnaryFn::exp:
477 return math::ExpOp::create(builder, arg.getLoc(), arg);
478 case UnaryFn::log:
479 return math::LogOp::create(builder, arg.getLoc(), arg);
480 case UnaryFn::abs:
481 return math::AbsFOp::create(builder, arg.getLoc(), arg);
482 case UnaryFn::ceil:
483 return math::CeilOp::create(builder, arg.getLoc(), arg);
484 case UnaryFn::floor:
485 return math::FloorOp::create(builder, arg.getLoc(), arg);
486 case UnaryFn::negf:
487 return arith::NegFOp::create(builder, arg.getLoc(), arg);
488 case UnaryFn::reciprocal: {
489 Attribute oneAttr = builder.getOneAttr(arg.getType());
490 auto one = arith::ConstantOp::create(builder, arg.getLoc(),
491 ::cast<TypedAttr>(oneAttr));
492 return arith::DivFOp::create(builder, arg.getLoc(), one, arg);
493 }
494 case UnaryFn::round:
495 return math::RoundOp::create(builder, arg.getLoc(), arg);
496 case UnaryFn::sqrt:
497 return math::SqrtOp::create(builder, arg.getLoc(), arg);
498 case UnaryFn::rsqrt:
499 return math::RsqrtOp::create(builder, arg.getLoc(), arg);
500 case UnaryFn::square:
501 return arith::MulFOp::create(builder, arg.getLoc(), arg, arg);
502 case UnaryFn::tanh:
503 return math::TanhOp::create(builder, arg.getLoc(), arg);
504 case UnaryFn::erf:
505 return math::ErfOp::create(builder, arg.getLoc(), arg);
506 case UnaryFn::sin:
507 return math::SinOp::create(builder, arg.getLoc(), arg);
508 case UnaryFn::cos:
509 return math::CosOp::create(builder, arg.getLoc(), arg);
510 case UnaryFn::tan:
511 return math::TanOp::create(builder, arg.getLoc(), arg);
512 case UnaryFn::acos:
513 return math::AcosOp::create(builder, arg.getLoc(), arg);
514 case UnaryFn::acosh:
515 return math::AcoshOp::create(builder, arg.getLoc(), arg);
516 case UnaryFn::asin:
517 return math::AsinOp::create(builder, arg.getLoc(), arg);
518 case UnaryFn::asinh:
519 return math::AsinhOp::create(builder, arg.getLoc(), arg);
520 case UnaryFn::atan:
521 return math::AtanOp::create(builder, arg.getLoc(), arg);
522 case UnaryFn::atanh:
523 return math::AtanhOp::create(builder, arg.getLoc(), arg);
524 case UnaryFn::log10:
525 return math::Log10Op::create(builder, arg.getLoc(), arg);
526 case UnaryFn::log1p:
527 return math::Log1pOp::create(builder, arg.getLoc(), arg);
528 case UnaryFn::log2:
529 return math::Log2Op::create(builder, arg.getLoc(), arg);
530 }
531 if (emitError) {
532 emitError() << "unsupported unary function";
533 return nullptr;
534 }
535 llvm_unreachable("unsupported unary function");
536 }
537
538 // Build the binary functions defined by OpDSL.
539 // If emitError is provided, an error will be emitted if the operation is not
540 // supported and a nullptr will be returned, otherwise an assertion will be
541 // raised.
542 Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1,
543 function_ref<InFlightDiagnostic()> emitError = {}) {
544 bool allComplex = isComplex(arg0) && isComplex(arg1);
545 bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
546 bool allInteger = isInteger(arg0) && isInteger(arg1);
547 bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
548 arg1.getType().getIntOrFloatBitWidth() == 1;
549 if (!allComplex && !allFloatingPoint && !allInteger) {
550 if (emitError) {
551 emitError()
552 << "Cannot build binary Linalg operation: expects allComplex, "
553 "allFloatingPoint, or allInteger, got "
554 << arg0.getType() << " and " << arg1.getType();
555 return nullptr;
556 }
557 llvm_unreachable("unsupported non numeric type");
558 }
559 OpBuilder::InsertionGuard g(builder);
560 builder.setInsertionPointToEnd(&block);
561 switch (binaryFn) {
562 case BinaryFn::add:
563 if (allComplex)
564 return complex::AddOp::create(builder, arg0.getLoc(), arg0, arg1);
565 if (allFloatingPoint)
566 return arith::AddFOp::create(builder, arg0.getLoc(), arg0, arg1);
567 if (allBool)
568 return arith::OrIOp::create(builder, arg0.getLoc(), arg0, arg1);
569 return arith::AddIOp::create(builder, arg0.getLoc(), arg0, arg1);
570 case BinaryFn::sub:
571 if (allComplex)
572 return complex::SubOp::create(builder, arg0.getLoc(), arg0, arg1);
573 if (allFloatingPoint)
574 return arith::SubFOp::create(builder, arg0.getLoc(), arg0, arg1);
575 if (allBool) {
576 if (emitError) {
577 emitError() << "unsupported operation: sub with bools";
578 return nullptr;
579 }
580 llvm_unreachable("unsupported operation: sub with bools");
581 }
582 return arith::SubIOp::create(builder, arg0.getLoc(), arg0, arg1);
583 case BinaryFn::mul:
584 if (allComplex)
585 return complex::MulOp::create(builder, arg0.getLoc(), arg0, arg1);
586 if (allFloatingPoint)
587 return arith::MulFOp::create(builder, arg0.getLoc(), arg0, arg1);
588 if (allBool)
589 return arith::AndIOp::create(builder, arg0.getLoc(), arg0, arg1);
590 return arith::MulIOp::create(builder, arg0.getLoc(), arg0, arg1);
591 case BinaryFn::div:
592 if (allComplex)
593 return complex::DivOp::create(builder, arg0.getLoc(), arg0, arg1);
594 if (allFloatingPoint)
595 return arith::DivFOp::create(builder, arg0.getLoc(), arg0, arg1);
596 if (allBool) {
597 if (emitError) {
598 emitError() << "unsupported operation: div with bools";
599 return nullptr;
600 }
601 llvm_unreachable("unsupported operation: div with bools");
602 }
603 return arith::DivSIOp::create(builder, arg0.getLoc(), arg0, arg1);
604 case BinaryFn::div_unsigned:
605 if (!allInteger || allBool) {
606 if (emitError) {
607 emitError() << "unsupported operation: unsigned div not on uint";
608 return nullptr;
609 }
610 llvm_unreachable("unsupported operation: unsigned div not on uint");
611 }
612 return arith::DivUIOp::create(builder, arg0.getLoc(), arg0, arg1);
613 case BinaryFn::max_signed:
614 assert(!allComplex);
615 if (allFloatingPoint)
616 return arith::MaximumFOp::create(builder, arg0.getLoc(), arg0, arg1);
617 return arith::MaxSIOp::create(builder, arg0.getLoc(), arg0, arg1);
618 case BinaryFn::min_signed:
619 assert(!allComplex);
620 if (allFloatingPoint)
621 return arith::MinimumFOp::create(builder, arg0.getLoc(), arg0, arg1);
622 return arith::MinSIOp::create(builder, arg0.getLoc(), arg0, arg1);
623 case BinaryFn::max_unsigned:
624 assert(!allComplex);
625 if (!allInteger || allBool) {
626 if (emitError) {
627 emitError() << "unsupported operation: unsigned max not on uint";
628 return nullptr;
629 }
630 llvm_unreachable("unsupported operation: unsigned max not on uint");
631 }
632 return arith::MaxUIOp::create(builder, arg0.getLoc(), arg0, arg1);
633 case BinaryFn::min_unsigned:
634 assert(!allComplex);
635 if (!allInteger || allBool) {
636 if (emitError) {
637 emitError() << "unsupported operation: unsigned min not on uint";
638 return nullptr;
639 }
640 llvm_unreachable("unsupported operation: unsigned min not on uint");
641 }
642 return arith::MinUIOp::create(builder, arg0.getLoc(), arg0, arg1);
643 case BinaryFn::powf:
644 assert(allFloatingPoint);
645 return math::PowFOp::create(builder, arg0.getLoc(), arg0, arg1);
646 }
647 if (emitError) {
648 emitError() << "unsupported binary function";
649 return nullptr;
650 }
651 llvm_unreachable("unsupported binary function");
652 }
653
654 // Build the ternary functions defined by OpDSL.
655 Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1, Value arg2,
656 function_ref<InFlightDiagnostic()> emitError = {}) {
657 OpBuilder::InsertionGuard g(builder);
658 builder.setInsertionPointToEnd(&block);
659 switch (ternaryFn) {
660 case TernaryFn::select:
661 return arith::SelectOp::create(builder, arg0.getLoc(), arg0, arg1, arg2);
662 }
663 if (emitError) {
664 emitError() << "unsupported ternary function";
665 return nullptr;
666 }
667 llvm_unreachable("unsupported ternary function");
668 }
669
670 // Build the type functions defined by OpDSL.
671 Value buildTypeFn(TypeFn typeFn, Type toType, Value operand,
672 function_ref<InFlightDiagnostic()> emitError = {}) {
673 switch (typeFn) {
674 case TypeFn::cast_signed:
675 return cast(toType, operand, false);
676 case TypeFn::cast_unsigned:
677 return cast(toType, operand, true);
678 }
679 if (emitError) {
680 emitError() << "unsupported type conversion function";
681 return nullptr;
682 }
683 llvm_unreachable("unsupported type conversion function");
684 }
685
686 void yieldOutputs(ValueRange values) {
687 OpBuilder::InsertionGuard g(builder);
688 builder.setInsertionPointToEnd(&block);
689 Location loc = builder.getUnknownLoc();
690 YieldOp::create(builder, loc, values);
691 }
692
693 Value constant(const std::string &value) {
694 OpBuilder::InsertionGuard g(builder);
695 builder.setInsertionPointToEnd(&block);
696 Location loc = builder.getUnknownLoc();
697 Attribute valueAttr = parseAttribute(value, builder.getContext());
698 return arith::ConstantOp::create(builder, loc,
699 ::cast<TypedAttr>(valueAttr));
700 }
701
702 Value index(int64_t dim) {
703 OpBuilder::InsertionGuard g(builder);
704 builder.setInsertionPointToEnd(&block);
705 return IndexOp::create(builder, builder.getUnknownLoc(), dim);
706 }
707
708 Type getIntegerType(unsigned width) {
709 return IntegerType::get(builder.getContext(), width);
710 }
711
712 Type getFloat32Type() { return Float32Type::get(builder.getContext()); }
713 Type getFloat64Type() { return Float64Type::get(builder.getContext()); }
714
715private:
716 // Generates operations to cast the given operand to a specified type.
717 // If the cast cannot be performed, a warning will be issued and the
718 // operand returned as-is (which will presumably yield a verification
719 // issue downstream).
720 Value cast(Type toType, Value operand, bool isUnsignedCast) {
721 OpBuilder::InsertionGuard g(builder);
722 builder.setInsertionPointToEnd(&block);
723 auto loc = operand.getLoc();
724 if (isa<UnknownLoc>(loc)) {
725 if (operand.getDefiningOp())
726 loc = operand.getDefiningOp()->getLoc();
727 else if (operand.getParentBlock() &&
728 operand.getParentBlock()->getParentOp())
729 loc = operand.getParentBlock()->getParentOp()->getLoc();
730 }
731 return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast);
732 }
733
734 bool isComplex(Value value) {
735 return llvm::isa<ComplexType>(value.getType());
736 }
737 bool isFloatingPoint(Value value) {
738 return llvm::isa<FloatType>(value.getType());
739 }
740 bool isInteger(Value value) {
741 return llvm::isa<IntegerType>(value.getType());
742 }
743
744 OpBuilder &builder;
745 Block &block;
746};
747
748} // namespace
749
750//===----------------------------------------------------------------------===//
751// CopyOp
752//===----------------------------------------------------------------------===//
753
754namespace {
755
756struct EraseSelfCopy : OpRewritePattern<CopyOp> {
757 using OpRewritePattern<CopyOp>::OpRewritePattern;
758 LogicalResult matchAndRewrite(CopyOp copyOp,
759 PatternRewriter &rewriter) const override {
760 if (copyOp.getInputs() != copyOp.getOutputs())
761 return rewriter.notifyMatchFailure(copyOp, "not a self copy");
762 if (copyOp.hasPureBufferSemantics())
763 rewriter.eraseOp(copyOp);
764 else
765 rewriter.replaceOp(copyOp, copyOp.getInputs());
766
767 return success();
768 }
769};
770
771} // namespace
772
773void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
774 MLIRContext *context) {
775 results.add<EraseSelfCopy>(context);
776}
777
778//===----------------------------------------------------------------------===//
779// FillOp
780//===----------------------------------------------------------------------===//
781
782namespace {
783
784/// Fold linalg.fill -> tensor.expand/collapse_shape chain.
785///
786/// For such op chains, we can create new linalg.fill ops with the result
787/// type of the tensor.expand/collapse_shape op.
788template <typename TensorReshapeOp>
789struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
790 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
791 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
792 PatternRewriter &rewriter) const override {
793 auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
794 if (!oldFill)
795 return failure();
796
797 Location loc = oldFill.getLoc();
798 TensorReshapeOp newInit;
799 if constexpr (std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value) {
800
801 newInit = TensorReshapeOp::create(
802 rewriter, loc, reshapeOp.getResultType(), oldFill.output(),
803 reshapeOp.getReassociation(), reshapeOp.getOutputShape(),
804 reshapeOp.getStaticOutputShape());
805 } else {
806 newInit = TensorReshapeOp::create(
807 rewriter, loc, reshapeOp.getResultType(), oldFill.output(),
808 reshapeOp.getReassociation());
809 }
810 rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, ValueRange{oldFill.value()},
811 ValueRange{newInit});
812 return success();
813 }
814};
815
816/// Fold tensor.pad(linalg.fill) into linalg.fill if the padding value and the
817/// filling value are the same.
// NOTE(review): one original line is missing from this listing right after the
// struct header (numbering jumps 818 -> 820) - confirm against the complete
// file.
818struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> {
820
821 LogicalResult matchAndRewrite(tensor::PadOp padOp,
822 PatternRewriter &rewriter) const override {
// Only applies when the padded source is produced by a fill.
823 auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
824 if (!fillOp)
825 return failure();
826
827 // We can only fold if the padding value is the same as the original
828 // filling value.
829 Value padValue = padOp.getConstantPaddingValue();
830 if (!padValue || fillOp.value() != padValue)
831 return failure();
832
// The replacement needs the (possibly dynamic) shape of the pad result.
833 ReifiedRankedShapedTypeDims reifiedShape;
834 if (failed(reifyResultShapes(rewriter, padOp, reifiedShape)))
835 return rewriter.notifyMatchFailure(
836 padOp, "failed to reify tensor.pad op result shape");
837
// Build an empty tensor with the reified shape of the pad's first result and
// fill it with the padding value.
838 auto emptyTensor =
839 tensor::EmptyOp::create(rewriter, padOp.getLoc(), reifiedShape.front(),
840 padOp.getResultType().getElementType());
841 Value replacement =
842 FillOp::create(rewriter, fillOp.getLoc(), ValueRange{padValue},
843 ValueRange{emptyTensor})
844 .getResult(0);
// The new fill's type may differ from the pad's declared result type; cast
// back so that existing users keep seeing the same type.
845 if (replacement.getType() != padOp.getResultType()) {
846 replacement = tensor::CastOp::create(rewriter, fillOp.getLoc(),
847 padOp.getResultType(), replacement);
848 }
849 rewriter.replaceOp(padOp, replacement);
850 return success();
851 }
852};
853
854/// Fold tensor.insert_slice(tensor.pad(<input>), linalg.fill) into
855/// tensor.insert_slice(<input>, linalg.fill) if the padding value and the
856/// filling value are the same.
857struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
859
860 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
861 PatternRewriter &rewriter) const override {
862 auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
863 if (!srcPadOp)
864 return failure();
865
866 if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
867 return failure();
868
869 // Walk back the tensor.insert_slice chain and find the first destination
870 // value at the start of the chain.
871 Value firstDest = insertOp.getDest();
872 while (auto prevOp = firstDest.getDefiningOp<tensor::InsertSliceOp>()) {
873 if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
874 return failure();
875
876 // Make sure the range of values accessed are disjoint. Without this, we
877 // cannot fold tensor.pad away.
878 bool disjoint = false;
879 for (int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
880 // If the dimension has dynamic offset/size, we cannot guarantee
881 // disjoint. So just skip it.
882 if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
883 insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
884 prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
885 continue;
886
887 // Get the range start and end, inclusively for both.
888 int64_t prevStart = prevOp.getStaticOffset(i);
889 int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
890 prevOp.getStaticStride(i);
891 int64_t nextStart = insertOp.getStaticOffset(i);
892 int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
893 insertOp.getStaticStride(i);
894 if (prevEnd < nextStart || nextEnd < prevStart) {
895 disjoint = true;
896 break;
897 }
898 }
899
900 if (!disjoint)
901 break;
902 firstDest = prevOp.getDest();
903 }
904
905 // Check whether the first destination is a fill op. For overlapped cases,
906 // this also cannot be true.
907 auto dstFillOp = firstDest.getDefiningOp<linalg::FillOp>();
908 if (!dstFillOp)
909 return failure();
910
911 // We can only fold if the padding value is the same as the original
912 // filling value.
913 Value padValue = srcPadOp.getConstantPaddingValue();
914 if (!padValue || dstFillOp.value() != padValue)
915 return failure();
916
917 SmallVector<OpFoldResult> lowPads = srcPadOp.getMixedLowPad();
918 SmallVector<OpFoldResult> oldOffsets = insertOp.getMixedOffsets();
919
920 Location loc = insertOp.getLoc();
921 MLIRContext *context = getContext();
922
923 AffineExpr sym0, sym1;
924 bindSymbols(context, sym0, sym1);
925 auto addMap = AffineMap::get(0, 2, {sym0 + sym1}, context);
926
927 // Calculate the new offsets for the insert. It should be the old offsets
928 // plus low padding sizes.
929 SmallVector<OpFoldResult, 4> newOffsets;
930 for (const auto &p : llvm::zip(lowPads, oldOffsets)) {
931 newOffsets.push_back(affine::makeComposedFoldedAffineApply(
932 rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
933 }
934
935 RankedTensorType srcPadType = srcPadOp.getSourceType();
936 SmallVector<OpFoldResult, 4> newSizes;
937 for (int i = 0, e = srcPadType.getRank(); i < e; ++i) {
938 if (srcPadType.isDynamicDim(i)) {
939 newSizes.push_back(
940 tensor::DimOp::create(rewriter, loc, srcPadOp.getSource(), i)
941 .getResult());
942 } else {
943 newSizes.push_back(rewriter.getIndexAttr(srcPadType.getDimSize(i)));
944 }
945 }
946
947 rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
948 insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
949 newSizes, insertOp.getMixedStrides());
950 return success();
951 }
952};
953
954/// Fold tensor.extract(linalg.fill(<input>)) into <input>
955struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
956public:
957 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
958
959 LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
960 PatternRewriter &rewriter) const override {
961 // See if tensor input of tensor.extract op is the result of a linalg.fill
962 // op.
963 auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
964 if (!fillOp)
965 return failure();
966
967 // Get scalar input operand of linalg.fill op.
968 Value extractedScalar = fillOp.getInputs()[0];
969
970 // Replace tensor.extract op with scalar value used to fill the tensor.
971 rewriter.replaceOp(extractOp, extractedScalar);
972 return success();
973 }
974};
975
976/// Folds pack(fill) into a single fill op if
977/// 1. The pack op does not have padding value, or
978/// 2. The filled value and padding value are the same.
979static FailureOr<FillOp> foldFillPackIntoFillOp(RewriterBase &rewriter,
980 linalg::PackOp packOp) {
981 auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
982 if (!fillOp)
983 return failure();
984
985 if (auto paddingValue = packOp.getPaddingValue())
986 if (!isEqualConstantIntOrValue(paddingValue, fillOp.value()))
987 return failure();
988
989 Value packOpDest = packOp.getDest();
990 if (!packOpDest.hasOneUse())
991 return failure();
992
993 return linalg::FillOp::create(rewriter, packOp.getLoc(), fillOp.getInputs(),
994 packOp.getDest());
995}
996
997/// Wrapper pattern that applies foldFillPackIntoFillOp method.
998struct FoldFillWithPack : public OpRewritePattern<linalg::PackOp> {
999public:
1000 FoldFillWithPack(MLIRContext *context)
1001 : OpRewritePattern<linalg::PackOp>(context) {}
1002
1003 LogicalResult matchAndRewrite(linalg::PackOp packOp,
1004 PatternRewriter &rewriter) const override {
1005 auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);
1006 if (failed(fillOp))
1007 return failure();
1008 rewriter.replaceOp(packOp, fillOp.value().result());
1009 return success();
1010 }
1011};
1012
1013/// Fold fill with copy.
1014struct FoldFillWithCopy : OpRewritePattern<linalg::CopyOp> {
1015 using OpRewritePattern<linalg::CopyOp>::OpRewritePattern;
1016
1017 LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
1018 PatternRewriter &rewriter) const override {
1019 if (auto fillOp = copyOp.getInputs().front().getDefiningOp<FillOp>()) {
1020 rewriter.replaceOpWithNewOp<FillOp>(copyOp, copyOp.getResultTypes(),
1021 fillOp.getInputs(),
1022 copyOp.getOutputs());
1023 return success();
1024 }
1025 if (auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>()) {
1026 rewriter.replaceOpWithNewOp<linalg::CopyOp>(copyOp, copyOp.getInputs(),
1027 fillOp.getOutputs());
1028 return success();
1029 }
1030 return failure();
1031 }
1032};
1033
1034/// Fold fill with transpose.
1035struct FoldFillWithTranspose : OpRewritePattern<linalg::TransposeOp> {
1036 using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
1037
1038 LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
1039 PatternRewriter &rewriter) const override {
1040 if (auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>()) {
1041 rewriter.replaceOpWithNewOp<FillOp>(
1042 transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
1043 transposeOp.getDpsInitOperand(0)->get());
1044 return success();
1045 }
1046 return failure();
1047 }
1048};
1049
1050/// Fold a concat with all elements being fills of the same value
1051/// into a fill of the concat result shape.
1052struct FoldConcatsOfFill : public OpRewritePattern<tensor::ConcatOp> {
1054
1055 LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
1056 PatternRewriter &rewriter) const override {
1057 auto concatOperands = concatOp.getInputs();
1058 if (concatOperands.empty()) {
1059 return failure();
1060 }
1061
1062 auto firstFillOp = concatOperands.front().getDefiningOp<linalg::FillOp>();
1063 if (!firstFillOp) {
1064 return failure();
1065 }
1066 // Prefetch the fill value.
1067 OpFoldResult firstFillVal =
1068 getAsOpFoldResult(firstFillOp.getDpsInputOperand(0)->get());
1069 // Collect all the outs values for the fill operations.
1070 SmallVector<Value> allOuts;
1071 allOuts.push_back(firstFillOp.getDpsInitOperand(0)->get());
1072
1073 auto isDefinedByCompatibleFillOp = [&](Value v) -> bool {
1074 auto fillOp = v.getDefiningOp<linalg::FillOp>();
1075 if (!fillOp) {
1076 return false;
1077 }
1078
1079 OpFoldResult fillVal =
1080 getAsOpFoldResult(fillOp.getDpsInputOperand(0)->get());
1081 if (fillVal != firstFillVal)
1082 return false;
1083
1084 allOuts.push_back(fillOp.getDpsInitOperand(0)->get());
1085 return true;
1086 };
1087 if (!llvm::all_of(concatOperands.drop_front(),
1088 isDefinedByCompatibleFillOp)) {
1089 return rewriter.notifyMatchFailure(
1090 concatOp, "not all operands are defined by a compatible fill op");
1091 }
1092
1093 Value outsConcat = tensor::ConcatOp::create(rewriter, concatOp.getLoc(),
1094 concatOp.getDim(), allOuts);
1095 rewriter.replaceOpWithNewOp<linalg::FillOp>(
1096 concatOp, firstFillOp.getDpsInputOperand(0)->get(), outsConcat);
1097 return success();
1098 }
1099};
1100
1101} // namespace
1102
1103void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
1104 MLIRContext *context) {
1105 results.add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
1106 FoldFillWithPack, FoldFillWithPad,
1107 FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
1108 FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
1109 FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
1110}
1111
1112//===----------------------------------------------------------------------===//
1113// GenericOp
1114//===----------------------------------------------------------------------===//
1115
1117 OpBuilder &builder, Location loc, Region &region, ValueRange inputs,
1118 ValueRange outputs,
1119 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
1120 SmallVector<Type, 4> blockArgTypes;
1121 SmallVector<Location, 4> blockArgLocs;
1122 for (ValueRange container : {inputs, outputs}) {
1123 for (Value v : container) {
1124 Type t = v.getType();
1125 blockArgTypes.push_back(
1126 isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);
1127 blockArgLocs.push_back(v.getLoc());
1128 }
1129 }
1130
1131 OpBuilder::InsertionGuard guard(builder);
1132 Block *bodyBlock =
1133 builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
1134 bodyBuild(builder, loc, bodyBlock->getArguments());
1135}
1136
1137void GenericOp::getAsmBlockArgumentNames(Region &region,
1138 OpAsmSetValueNameFn setNameFn) {
1139 for (Value v : getRegionInputArgs())
1140 setNameFn(v, "in");
1141 for (Value v : getRegionOutputArgs())
1142 setNameFn(v, "out");
1143}
1144
1145void GenericOp::build(
1146 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
1147 ValueRange inputs, ValueRange outputs, ArrayAttr indexingMaps,
1148 ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
1149 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1150 ArrayRef<NamedAttribute> attributes) {
1151 build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
1152 iteratorTypes, doc, libraryCall);
1153 result.addAttributes(attributes);
1154 if (bodyBuild)
1155 buildGenericRegion(builder, result.location, *result.regions.front(),
1156 inputs, outputs, bodyBuild);
1157}
1158
1159void GenericOp::build(
1160 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
1161 ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
1162 ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
1163 StringRef libraryCall,
1164 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1165 ArrayRef<NamedAttribute> attributes) {
1166 build(builder, result, resultTensorTypes, inputs, outputs,
1167 builder.getAffineMapArrayAttr(indexingMaps),
1168 builder.getArrayAttr(llvm::map_to_vector(
1169 iteratorTypes,
1170 [&](utils::IteratorType iter) -> mlir::Attribute {
1171 return IteratorTypeAttr::get(builder.getContext(), iter);
1172 })),
1173 doc.empty() ? StringAttr() : builder.getStringAttr(doc),
1174 libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall),
1175 bodyBuild, attributes);
1176}
1177
1178void GenericOp::build(
1179 OpBuilder &builder, OperationState &result, ValueRange inputs,
1180 ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
1181 ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
1182 StringRef libraryCall,
1183 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1184 ArrayRef<NamedAttribute> attributes) {
1185 build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
1186 iteratorTypes, doc, libraryCall, bodyBuild, attributes);
1187}
1188
1189void GenericOp::build(
1190 OpBuilder &builder, OperationState &result, ValueRange inputs,
1191 ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
1192 ArrayRef<utils::IteratorType> iteratorTypes,
1193 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1194 ArrayRef<NamedAttribute> attributes) {
1195 build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
1196 /*doc=*/"",
1197 /*libraryCall=*/"", bodyBuild, attributes);
1198}
1199
1200void GenericOp::build(
1201 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
1202 ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
1203 ArrayRef<utils::IteratorType> iteratorTypes,
1204 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1205 ArrayRef<NamedAttribute> attributes) {
1206 build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
1207 iteratorTypes,
1208 /*doc=*/"",
1209 /*libraryCall=*/"", bodyBuild, attributes);
1210}
1211
1212void GenericOp::print(OpAsmPrinter &p) {
1213 p << " ";
1214
1215 // Print extra attributes.
1216 auto genericAttrNames = linalgTraitAttrNames();
1217
1218 llvm::StringSet<> genericAttrNamesSet;
1219 genericAttrNamesSet.insert_range(genericAttrNames);
1220 SmallVector<NamedAttribute, 8> genericAttrs;
1221 for (auto attr : (*this)->getAttrs()) {
1222 if (attr.getName() == getIteratorTypesAttrName()) {
1223 auto iteratorTypes =
1224 llvm::cast<ArrayAttr>(attr.getValue())
1225 .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
1226 // Convert IteratorType enums into the string representation. This is
1227 // needed, because tests still use the old format when 'iterator_types'
1228 // attribute is represented as an array of strings.
1229 // TODO: Remove this conversion once tests are fixed.
1230 SmallVector<Attribute> iteratorTypeNames = llvm::map_to_vector(
1231 iteratorTypes, [&](utils::IteratorType t) -> Attribute {
1232 return StringAttr::get(getContext(), stringifyIteratorType(t));
1233 });
1234
1235 genericAttrs.emplace_back(
1236 getIteratorTypesAttrName(),
1237 ArrayAttr::get(getContext(), iteratorTypeNames));
1238 } else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
1239 genericAttrs.push_back(attr);
1240 }
1241 }
1242 if (!genericAttrs.empty()) {
1243 auto genericDictAttr = DictionaryAttr::get(getContext(), genericAttrs);
1244 p << genericDictAttr;
1245 }
1246
1247 // Printing is shared with named ops, except for the region and attributes
1248 printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1249
1250 genericAttrNames.push_back("operandSegmentSizes");
1251 genericAttrNamesSet.insert(genericAttrNames.back());
1252
1253 bool hasExtraAttrs = false;
1254 for (NamedAttribute n : (*this)->getAttrs()) {
1255 if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
1256 break;
1257 }
1258 if (hasExtraAttrs) {
1259 p << " attrs = ";
1260 p.printOptionalAttrDict((*this)->getAttrs(),
1261 /*elidedAttrs=*/genericAttrNames);
1262 }
1263
1264 // Print region.
1265 if (!getRegion().empty()) {
1266 p << ' ';
1267 p.printRegion(getRegion());
1268 }
1269
1270 // Print results.
1271 printNamedStructuredOpResults(p, getResultTensors().getTypes());
1272}
1273
1274ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
1275 DictionaryAttr dictAttr;
1276 // Parse the core linalg traits that must check into a dictAttr.
1277 // The name is unimportant as we will overwrite result.attributes.
1278 // The core linalg traits must contain the information necessary to pass the
1279 // verifier.
1280 llvm::SMLoc attributeLocation = parser.getCurrentLocation();
1281 if (parser.parseAttribute(dictAttr, "_", result.attributes))
1282 return failure();
1283 result.attributes.assign(dictAttr.getValue().begin(),
1284 dictAttr.getValue().end());
1285
1286 // Convert array of string into an array of IteratorType enums. This is
1287 // needed, because tests still use the old format when 'iterator_types'
1288 // attribute is represented as an array of strings.
1289 // TODO: Remove this conversion once tests are fixed.
1290 auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
1291 result.attributes.get(getIteratorTypesAttrName(result.name)));
1292 if (!iteratorTypes) {
1293 return parser.emitError(attributeLocation)
1294 << "expected " << getIteratorTypesAttrName(result.name)
1295 << " array attribute";
1296 }
1297
1298 SmallVector<Attribute> iteratorTypeAttrs;
1299
1300 for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
1301 auto maybeIteratorType = utils::symbolizeIteratorType(s);
1302 if (!maybeIteratorType.has_value())
1303 return parser.emitError(parser.getCurrentLocation())
1304 << "unexpected iterator_type (" << s << ")";
1305
1306 iteratorTypeAttrs.push_back(
1307 IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
1308 }
1309 result.attributes.set(getIteratorTypesAttrName(result.name),
1310 parser.getBuilder().getArrayAttr(iteratorTypeAttrs));
1311
1312 // Parsing is shared with named ops, except for the region.
1313 SmallVector<Type, 1> inputTypes, outputTypes;
1314 if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
1315 return failure();
1316
1317 // Optional attributes may be added.
1318 if (succeeded(parser.parseOptionalKeyword("attrs")))
1319 if (failed(parser.parseEqual()) ||
1320 failed(parser.parseOptionalAttrDict(result.attributes)))
1321 return failure();
1322
1323 std::unique_ptr<Region> region = std::make_unique<Region>();
1324 if (parser.parseRegion(*region, {}))
1325 return failure();
1326 result.addRegion(std::move(region));
1327
1328 // Generic ops may specify that a subset of its outputs are tensors. Such
1329 // outputs are specified in the result type.
1330 // TODO: may need to move output parsing before region parsing.
1331 // Need to wait for declarative assembly resolution to decide.
1332 SmallVector<Type, 1> outputTensorsTypes;
1333 if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
1334 return failure();
1335 result.addTypes(outputTensorsTypes);
1336
1337 return success();
1338}
1339
1342 &effects,
1343 LinalgOp linalgOp) {
1344 for (auto [index, operand] : llvm::enumerate(linalgOp.getDpsInputs())) {
1345 if (!llvm::isa<MemRefType>(operand.getType()))
1346 continue;
1347 effects.emplace_back(
1348 MemoryEffects::Read::get(), &linalgOp->getOpOperand(index), /*stage=*/0,
1349 /*effectOnFullRegion=*/true, SideEffects::DefaultResource::get());
1350 }
1351
1352 for (OpOperand &operand : linalgOp.getDpsInitsMutable()) {
1353 if (!llvm::isa<MemRefType>(operand.get().getType()))
1354 continue;
1355 if (linalgOp.payloadUsesValueFromOperand(&operand)) {
1356 effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
1357 /*effectOnFullRegion=*/true,
1359 }
1360 effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
1361 /*effectOnFullRegion=*/true,
1363 }
1364}
1365
1366void GenericOp::getEffects(
1367 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1368 &effects) {
1369 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1370}
1371
1374 // Operands with value semantics are speculatable, while operands with memory
1375 // semantics are not.
1376 if (!linalgOp.hasPureTensorSemantics())
1378 // The body of the op can still have speculation in its region.
1380}
1381
1382Speculation::Speculatability GenericOp::getSpeculatability() {
1383 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
1384}
1385
1386namespace {
1387
1388/// Remove linalg operations that are just copying the values from inputs to
1389/// results. In the memref case, the operation must be copying to and from the
1390/// same value. Requirements are:
1391/// 1) All iterator types are parallel
1392/// 2) The body contains just a yield operation with the yielded values being
1393/// the arguments corresponding to the operands.
1394template <typename OpTy>
1395struct EraseIdentityLinalgOp : public OpRewritePattern<OpTy> {
1396 using OpRewritePattern<OpTy>::OpRewritePattern;
1397
1398 LogicalResult matchAndRewrite(OpTy linalgOp,
1399 PatternRewriter &rewriter) const override {
1400 // All indexing maps must be equal. It follows that they are permutations.
1401 if (!llvm::all_equal(linalgOp.getIndexingMapsArray()))
1402 return failure();
1403
1404 // Check that the body of the linalg operation is just a linalg.yield
1405 // operation.
1406 Block &body = linalgOp->getRegion(0).front();
1407 if (!llvm::hasSingleElement(body))
1408 return failure();
1409 auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
1410 if (!yieldOp)
1411 return failure();
1412
1413 // In the buffer case, we need to check exact buffer equality.
1414 if (linalgOp.hasPureBufferSemantics()) {
1415 if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1 ||
1416 linalgOp.getDpsInputOperand(0)->get() !=
1417 linalgOp.getDpsInitOperand(0)->get()) {
1418 return rewriter.notifyMatchFailure(
1419 linalgOp, "expected single input and output to be the same value");
1420 }
1421
1422 auto yieldArg = dyn_cast<BlockArgument>(yieldOp.getOperand(0));
1423 if (!yieldArg || yieldArg.getOwner() != &body) {
1424 return rewriter.notifyMatchFailure(linalgOp,
1425 "cannot fold fill-like op");
1426 }
1427
1428 rewriter.eraseOp(linalgOp);
1429 return success();
1430 }
1431
1432 if (!linalgOp.hasPureTensorSemantics()) {
1433 return rewriter.notifyMatchFailure(
1434 linalgOp, "mixed semantics is not supported yet");
1435 }
1436
1437 // Get the argument number of the returned values. That is the operand
1438 // number to use for replacing uses of this operation.
1439 SmallVector<Value> returnedArgs;
1440 for (const auto &yieldVal : llvm::enumerate(yieldOp.getValues())) {
1441 auto yieldArg = llvm::dyn_cast<BlockArgument>(yieldVal.value());
1442 if (!yieldArg || yieldArg.getOwner() != &body)
1443 return failure();
1444 unsigned argumentNumber = yieldArg.getArgNumber();
1445 Value returnedArg = linalgOp->getOperand(argumentNumber);
1446 Type resultType = linalgOp->getResult(yieldVal.index()).getType();
1447 // The input can have a different type than the result, e.g. a dynamic
1448 // input dimension can be turned into a static output dimension.
1449 Type returnType = returnedArg.getType();
1450 if (returnType != resultType) {
1451 // Distinguish between sparse conversion or dense tensor casting.
1452 // TODO: unify the two ops?
1455 returnedArg = sparse_tensor::ConvertOp::create(
1456 rewriter, linalgOp.getLoc(), resultType, returnedArg);
1457 else {
1458 if (!tensor::CastOp::areCastCompatible(returnedArg.getType(),
1459 resultType))
1460 return failure();
1461 returnedArg = tensor::CastOp::create(rewriter, linalgOp.getLoc(),
1462 resultType, returnedArg);
1463 }
1464 }
1465 returnedArgs.push_back(returnedArg);
1466 }
1467
1468 if (returnedArgs.size() != linalgOp->getNumResults())
1469 return failure();
1470 rewriter.replaceOp(linalgOp, returnedArgs);
1471 return success();
1472 }
1473};
1474
1475} // namespace
1476
1477void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
1478 MLIRContext *context) {
1479 results.add<EraseIdentityLinalgOp<GenericOp>>(context);
1480}
1481
1482LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
1483 return memref::foldMemRefCast(*this);
1484}
1485
1486//===----------------------------------------------------------------------===//
1487// MapOp
1488//===----------------------------------------------------------------------===//
1489
1490static ParseResult parseDstStyleOp(
1492 function_ref<ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn =
1493 nullptr) {
1494 // Parse `ins` and `outs`.
1495 SmallVector<Type, 4> inputTypes, outputTypes;
1496 if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes,
1497 /*addOperandSegmentSizes=*/false))
1498 return failure();
1499
1500 // Add result types.
1501 for (Type outputType : outputTypes) {
1502 if (llvm::isa<RankedTensorType>(outputType))
1503 result.addTypes(outputType);
1504 }
1505
1506 // Parse required attributes.
1507 if (parseAttrsFn && failed(parseAttrsFn(parser, result.attributes)))
1508 return failure();
1509
1510 // Parse optional attributes.
1511 if (parser.parseOptionalAttrDict(result.attributes))
1512 return failure();
1513 return success();
1514}
1515
1516void MapOp::getAsmBlockArgumentNames(Region &region,
1517 OpAsmSetValueNameFn setNameFn) {
1518 for (Value v : getRegionInputArgs())
1519 setNameFn(v, "in");
1520 for (Value v : getRegionOutputArgs())
1521 setNameFn(v, "init");
1522}
1523
1524void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1525 if (!getResults().empty())
1526 setNameFn(getResults().front(), "mapped");
1527}
1528
1529void MapOp::build(
1530 OpBuilder &builder, OperationState &result, ValueRange inputs, Value init,
1531 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1532 ArrayRef<NamedAttribute> attributes) {
1533 build(builder, result, TypeRange{}, inputs, init);
1534 result.addAttributes(attributes);
1535
1536 // Add output types for `RankedTensorType` output arguments.
1537 Type initType = init.getType();
1538 if (llvm::isa<RankedTensorType>(initType))
1539 result.addTypes(initType);
1540
1541 if (bodyBuild)
1542 buildGenericRegion(builder, result.location, *result.regions.front(),
1543 inputs, /*outputs=*/{init}, bodyBuild);
1544}
1545
1547 const OperationName &payloadOpName,
1548 const NamedAttrList &payloadOpAttrs,
1549 ArrayRef<Value> operands,
1550 bool initFirst = false, bool mapInit = true) {
1551 OpBuilder b(parser.getContext());
1552 Region *body = result.addRegion();
1553 Block &block = body->emplaceBlock();
1554 b.setInsertionPointToStart(&block);
1555 for (auto &operand : operands) {
1556 block.addArgument(
1557 llvm::cast<ShapedType>(operand.getType()).getElementType(),
1558 b.getUnknownLoc());
1559 }
1560 SmallVector<Value> payloadOpOperands;
1561 // If initFirst flag is enabled, we consider init as the first position of
1562 // payload operands.
1563 if (initFirst) {
1564 if (mapInit)
1565 payloadOpOperands.push_back(block.getArguments().back());
1566 for (const auto &arg : block.getArguments().drop_back())
1567 payloadOpOperands.push_back(arg);
1568 } else {
1569 payloadOpOperands = {block.getArguments().begin(),
1570 block.getArguments().end() - int(!mapInit)};
1571 }
1572
1573 Operation *payloadOp = b.create(
1574 result.location, b.getStringAttr(payloadOpName.getStringRef()),
1575 payloadOpOperands,
1576 TypeRange{llvm::cast<ShapedType>(result.operands.back().getType())
1577 .getElementType()},
1578 payloadOpAttrs);
1579 YieldOp::create(b, result.location, payloadOp->getResults());
1580}
1581
1582ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
1583 std::optional<OperationName> payloadOpName;
1584 NamedAttrList payloadOpAttrs;
1585 if (succeeded(parser.parseOptionalLBrace())) {
1586 FailureOr<OperationName> operationName = parser.parseCustomOperationName();
1587 if (failed(operationName))
1588 return failure();
1589 if (parser.parseOptionalAttrDict(payloadOpAttrs))
1590 return failure();
1591 payloadOpName = operationName.value();
1592 if (parser.parseRBrace())
1593 return failure();
1594 }
1595
1596 if (parseDstStyleOp(parser, result))
1597 return failure();
1598
1599 if (payloadOpName.has_value()) {
1600 if (!result.operands.empty())
1601 addBodyWithPayloadOp(parser, result, payloadOpName.value(),
1602 payloadOpAttrs, ArrayRef(result.operands), false,
1603 false);
1604 else
1605 result.addRegion();
1606 } else {
1607 SmallVector<OpAsmParser::Argument> regionArgs;
1608 if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
1609 /*allowType=*/true, /*allowAttrs=*/true)) {
1610 return failure();
1611 }
1612 Region *body = result.addRegion();
1613 if (parser.parseRegion(*body, regionArgs))
1614 return failure();
1615 }
1616 return success();
1617}
1618
1619static bool canUseShortForm(Block *body, bool initFirst = false,
1620 bool mapInit = true) {
1621 // `intFirst == true` implies that we want to map init arg
1622 if (initFirst && !mapInit)
1623 return false;
1624 // Check if the body can be printed in short form. The following 4 conditions
1625 // must be satisfied:
1626
1627 // 1) The body must contain exactly 2 operations: the payload op and a yield.
1628 if (body->getOperations().size() != 2)
1629 return false;
1630 Operation &payload = body->getOperations().front();
1631
1632 // 2) The payload op must have the same number of operands as the number of
1633 // block arguments.
1634 if (payload.getNumOperands() == 0 ||
1635 payload.getNumOperands() != body->getNumArguments() - int(!mapInit))
1636 return false;
1637
1638 // 3) If `initFirst` is true (e.g., for reduction ops), the init block
1639 // must be the first operand of the payload op, otherwise, the operands
1640 // must match the block arguments in order.
1641 if (initFirst) {
1642 // check init
1643 if (payload.getOperands().back() != body->getArgument(0))
1644 return false;
1645 // check rest
1646 for (const auto &[operand, bbArg] :
1647 llvm::zip(payload.getOperands(), body->getArguments().drop_front())) {
1648 if (bbArg != operand)
1649 return false;
1650 }
1651 } else {
1652 for (const auto &[operand, bbArg] :
1653 llvm::zip(payload.getOperands(),
1654 body->getArguments().drop_back(int(!mapInit)))) {
1655 if (bbArg != operand)
1656 return false;
1657 }
1658 }
1659
1660 // 4) The `yield` operand must be the result of the payload op.
1661 auto yieldOp = cast<YieldOp>(body->getTerminator());
1662 return yieldOp.getNumOperands() == 1 &&
1663 yieldOp.getOperand(0).getDefiningOp() &&
1664 yieldOp.getOperand(0).getDefiningOp() == &payload;
1665}
1666
1667static void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
1668 SmallVector<StringRef> elidedAttrs;
1669 std::string attrToElide;
1670 p << " { " << payloadOp->getName().getStringRef();
1671 for (const auto &attr : payloadOp->getAttrs()) {
1672 auto fastAttr =
1673 llvm::dyn_cast<mlir::arith::FastMathFlagsAttr>(attr.getValue());
1674 if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
1675 attrToElide = attr.getName().str();
1676 elidedAttrs.push_back(attrToElide);
1677 break;
1678 }
1679 }
1680 p.printOptionalAttrDict(payloadOp->getAttrs(), elidedAttrs);
1681 p << " }";
1682}
1683
1684void MapOp::print(OpAsmPrinter &p) {
1685 Block *mapper = getBody();
1686 bool useShortForm =
1687 canUseShortForm(mapper, /*initFirst=*/false, /*mapInit*/ false);
1688 if (useShortForm) {
1689 printShortForm(p, &mapper->getOperations().front());
1690 }
1691
1692 printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1693 p.printOptionalAttrDict((*this)->getAttrs());
1694
1695 if (!useShortForm) {
1696 // Print region if the payload op was not detected.
1697 p.increaseIndent();
1698 p.printNewline();
1699 p << "(";
1700 llvm::interleaveComma(mapper->getArguments(), p,
1701 [&](auto arg) { p.printRegionArgument(arg); });
1702 p << ") ";
1703
1704 p.printRegion(getMapper(), /*printEntryBlockArgs=*/false);
1705 p.decreaseIndent();
1706 }
1707}
1708
1709LogicalResult MapOp::verify() {
1710 auto *bodyBlock = getBody();
1711 auto blockArgs = bodyBlock->getArguments();
1712
1713 // Checks if the number of `inputs` + `init` match the arity of the `mapper`
1714 // region.
1715 if (getInputs().size() + 1 != blockArgs.size())
1716 return emitOpError() << "expects number of operands to match the arity of "
1717 "mapper, but got: "
1718 << getInputs().size() + 1 << " and "
1719 << blockArgs.size();
1720
1721 // The parameters of mapper should all match the element type of inputs.
1722 for (const auto &[bbArgType, inputArg] :
1723 llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
1724 auto inputElemType =
1725 llvm::cast<ShapedType>(inputArg.getType()).getElementType();
1726 if (bbArgType != inputElemType) {
1727 return emitOpError() << "expected element type of input " << inputElemType
1728 << " to match bbArg type " << bbArgType;
1729 }
1730 }
1731
1732 // The shape of each input must match the shape of the output.
1733 auto outputShape = getInit().getType().getShape();
1734 for (Type inputArgType : TypeRange{getInputs()}) {
1735 auto inputElemShape = llvm::cast<ShapedType>(inputArgType).getShape();
1736 if (inputElemShape != outputShape) {
1737 return emitOpError() << "expected shape of input (" << inputElemShape
1738 << ") to match shape of output (" << outputShape
1739 << ")";
1740 }
1741 }
1742
1743 return success();
1744}
1745
1746SmallVector<utils::IteratorType> MapOp::getIteratorTypesArray() {
1747 int64_t rank = getInit().getType().getRank();
1748 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
1749}
1750
1751ArrayAttr MapOp::getIndexingMaps() {
1752 Builder builder(getContext());
1753 int64_t rank = getInit().getType().getRank();
1754 int64_t numIndexingMaps = getOperands().size();
1755 return builder.getAffineMapArrayAttr(SmallVector<AffineMap>(
1756 numIndexingMaps, builder.getMultiDimIdentityMap(rank)));
1757}
1758
1759void MapOp::getEffects(
1760 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1761 &effects) {
1762 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1763}
1764
1765Speculation::Speculatability MapOp::getSpeculatability() {
1766 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
1767}
1768
1769//===----------------------------------------------------------------------===//
1770// ReduceOp
1771//===----------------------------------------------------------------------===//
1772
1773void ReduceOp::getAsmBlockArgumentNames(Region &region,
1774 OpAsmSetValueNameFn setNameFn) {
1775 for (Value v : getRegionInputArgs())
1776 setNameFn(v, "in");
1777 for (Value v : getRegionOutputArgs())
1778 setNameFn(v, "init");
1779}
1780
1781void ReduceOp::getAsmResultNames(
1782 function_ref<void(Value, StringRef)> setNameFn) {
1783 if (!getResults().empty())
1784 setNameFn(getResults().front(), "reduced");
1785}
1786
1787void ReduceOp::build(
1788 OpBuilder &builder, OperationState &result, ValueRange inputs,
1789 ValueRange inits, ArrayRef<int64_t> dimensions,
1790 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1791 ArrayRef<NamedAttribute> attributes) {
1792 build(builder, result, TypeRange{}, inputs, inits, dimensions);
1793 result.addAttributes(attributes);
1794
1795 // Add output types for `RankedTensorType` output arguments.
1796 for (Value init : inits) {
1797 Type initType = init.getType();
1798 if (llvm::isa<RankedTensorType>(initType))
1799 result.addTypes(initType);
1800 }
1801
1802 if (bodyBuild)
1803 buildGenericRegion(builder, result.location, *result.regions.front(),
1804 inputs, inits, bodyBuild);
1805}
1806
1807SmallVector<utils::IteratorType> ReduceOp::getIteratorTypesArray() {
1808 int64_t inputRank =
1809 llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
1810 SmallVector<utils::IteratorType> iteratorTypes(inputRank,
1811 utils::IteratorType::parallel);
1812 for (int64_t reductionDim : getDimensions())
1813 iteratorTypes[reductionDim] = utils::IteratorType::reduction;
1814 return iteratorTypes;
1815}
1816
1817ArrayAttr ReduceOp::getIndexingMaps() {
1818 int64_t inputRank =
1819 llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
1820 SmallVector<AffineMap> affineMaps(
1821 getNumDpsInputs(),
1823 AffineMap resultMap =
1825 .dropResults(getDimensions());
1826 for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i)
1827 affineMaps.push_back(resultMap);
1828 return Builder(getContext()).getAffineMapArrayAttr(affineMaps);
1829}
1830
1831void ReduceOp::getEffects(
1832 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1833 &effects) {
1834 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1835}
1836
1837Speculation::Speculatability ReduceOp::getSpeculatability() {
1838 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
1839}
1840
1841static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
1842 NamedAttrList &attributes,
1843 StringRef attributeName) {
1844 if (parser.parseKeyword(attributeName) || parser.parseEqual())
1845 return failure();
1846
1847 attributes.set(attributeName, DenseI64ArrayAttr::parse(parser, Type{}));
1848 return success();
1849}
1850
1851ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
1852 std::optional<OperationName> payloadOpName;
1853 NamedAttrList payloadOpAttrs;
1854 if (succeeded(parser.parseOptionalLBrace())) {
1855 FailureOr<OperationName> operationName = parser.parseCustomOperationName();
1856 if (failed(operationName))
1857 return failure();
1858 if (parser.parseOptionalAttrDict(payloadOpAttrs))
1859 return failure();
1860 payloadOpName = operationName.value();
1861 if (parser.parseRBrace())
1862 return failure();
1863 }
1864
1865 if (parseDstStyleOp(
1866 parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
1867 return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
1868 }))
1869 return failure();
1870
1871 if (payloadOpName.has_value()) {
1872 addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
1873 ArrayRef(result.operands), /*initFirst=*/true);
1874 } else {
1875 SmallVector<OpAsmParser::Argument> regionArgs;
1876 if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
1877 /*allowType=*/true, /*allowAttrs=*/true)) {
1878 return failure();
1879 }
1880
1881 Region *body = result.addRegion();
1882 if (parser.parseRegion(*body, regionArgs))
1883 return failure();
1884 }
1885
1886 return success();
1887}
1888
1889static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
1890 ArrayRef<int64_t> attributeValue) {
1891 p << ' ' << attributeName << " = [" << attributeValue << "] ";
1892}
1893
1894void ReduceOp::print(OpAsmPrinter &p) {
1895 Block *mapper = getBody();
1896 bool useShortForm = canUseShortForm(mapper, /*initFirst=*/true);
1897 if (useShortForm) {
1898 printShortForm(p, &mapper->getOperations().front());
1899 }
1900
1901 printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1902 printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
1903 p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
1904 if (!useShortForm) {
1905 // Print region if the payload op was not detected.
1906 p.increaseIndent();
1907 p.printNewline();
1908 p << "(";
1909 llvm::interleaveComma(mapper->getArguments(), p,
1910 [&](auto arg) { p.printRegionArgument(arg); });
1911 p << ") ";
1912
1913 p.printRegion(getCombiner(), /*printEntryBlockArgs=*/false);
1914 p.decreaseIndent();
1915 }
1916}
1917
1918LogicalResult ReduceOp::verify() {
1919 ArrayRef<int64_t> dimensionsRef = getDimensions();
1920
1921 // The ReduceOp uses `SameVariadicOperandSize`, which requires equal numbers
1922 // of inputs and inits. Detect a mismatch early: when they differ, the
1923 // ODS-generated getInputs()/getInits() accessors compute each group's size
1924 // via floordiv of the total operand count, producing incorrect slices that
1925 // would cause out-of-bounds accesses below.
1926 if (getInputs().size() != static_cast<size_t>(getNumDpsInputs()))
1927 return emitOpError()
1928 << "expected equal number of inputs and outputs (required by "
1929 "SameVariadicOperandSize), got "
1930 << getNumDpsInputs() << " input(s) and " << getNumDpsInits()
1931 << " output(s)";
1932
1933 if (getInputs().empty())
1934 return emitOpError() << "expected at least one input";
1935
1936 for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
1937 if (llvm::cast<ShapedType>(getInputs()[i].getType()).getShape() !=
1938 llvm::cast<ShapedType>(getInputs()[0].getType()).getShape()) {
1939 return emitOpError() << "expects all inputs to have the same shapes. "
1940 "Shape at input-index "
1941 << i
1942 << " is not equal to the shape at input-index 0.";
1943 }
1944 }
1945 for (int64_t i = 1; i < getNumDpsInits(); ++i) {
1946 if (llvm::cast<ShapedType>(getInits()[i].getType()).getShape() !=
1947 llvm::cast<ShapedType>(getInits()[0].getType()).getShape()) {
1948 return emitOpError() << "expects all outputs to have the same shapes. "
1949 "Shape at output-index "
1950 << i
1951 << " is not equal to the shape at output-index 0.";
1952 }
1953 }
1954 auto inputType = llvm::cast<ShapedType>(getInputs()[0].getType());
1955 auto initType = llvm::cast<ShapedType>(getInits()[0].getType());
1956
1957 DenseSet<int64_t> dimensionsToReduce;
1958 for (int64_t dimension : dimensionsRef) {
1959 if (dimension < 0 || dimension >= inputType.getRank()) {
1960 return emitOpError()
1961 << "dimensions for reduction should be in the range [0, "
1962 << inputType.getRank() - 1 << "].";
1963 }
1964 dimensionsToReduce.insert(dimension);
1965 }
1966
1967 auto inputDims = inputType.getShape();
1968 auto initDims = initType.getShape();
1969
1970 // Input dimensions that will be left after the reduction.
1971 SmallVector<int64_t> reducedInputDims;
1972 for (const auto &en : llvm::enumerate(inputDims)) {
1973 if (!dimensionsToReduce.count(en.index()))
1974 reducedInputDims.push_back(en.value());
1975 }
1976
1977 if (reducedInputDims.size() != static_cast<size_t>(initType.getRank())) {
1978 return emitOpError() << "number of dimensions after reduction "
1979 << reducedInputDims.size()
1980 << " doesn't match the init rank "
1981 << initType.getRank();
1982 }
1983
1984 if (reducedInputDims != initDims)
1985 return emitOpError() << "init dimensions [" << initDims
1986 << "] doesn't match input dimensions after reduction ["
1987 << reducedInputDims << "]";
1988
1989 Block *block = getBody();
1990 if (block->getNumArguments() != this->getNumOperands())
1991 return emitOpError()
1992 << "mismatching number of operands and block arguments";
1993
1994 // Check that the first block arguments match the element type of the inputs.
1995 for (auto [input, bbArg] : llvm::zip(getInputs(), block->getArguments())) {
1996 Type inputElementType =
1997 llvm::cast<ShapedType>(input.getType()).getElementType();
1998 if (inputElementType != bbArg.getType())
1999 return emitOpError()
2000 << "input element type " << inputElementType
2001 << " does not match corresponding block argument type "
2002 << bbArg.getType();
2003 }
2004
2005 // Check that the last block arguments match the element type of the outputs.
2006 for (auto [output, bbArg] : llvm::zip(
2007 getDpsInits(), block->getArguments().take_back(getNumDpsInits()))) {
2008 auto outputElementType =
2009 llvm::cast<ShapedType>(output.getType()).getElementType();
2010 if (outputElementType != bbArg.getType())
2011 return emitOpError()
2012 << "output element type " << outputElementType
2013 << " does not match corresponding block argument type "
2014 << bbArg.getType();
2015 }
2016 return success();
2017}
2018
2019//===----------------------------------------------------------------------===//
2020// TransposeOp
2021//===----------------------------------------------------------------------===//
2022
2023static void buildIdentityRegion(OpBuilder &builder, Location loc,
2024 Region &region, ValueRange inputs,
2025 ValueRange outputs) {
2026 buildGenericRegion(builder, loc, region, inputs, outputs,
2027 [](OpBuilder &b, Location loc, ValueRange args) {
2028 if (!args.empty())
2029 linalg::YieldOp::create(b, loc, args[0]);
2030 });
2031}
2032
2033void TransposeOp::build(::mlir::OpBuilder &builder,
2034 ::mlir::OperationState &result, Value input, Value init,
2035 DenseI64ArrayAttr permutation,
2036 ArrayRef<NamedAttribute> attributes) {
2037 result.addOperands(input);
2038 result.addOperands(init);
2039 result.addAttribute(getPermutationAttrName(result.name), permutation);
2040 result.addAttributes(attributes);
2041
2042 // Add output types for `RankedTensorType` output arguments.
2043 Type initType = init.getType();
2044 if (llvm::isa<RankedTensorType>(initType))
2045 result.addTypes(initType);
2046
2047 buildIdentityRegion(builder, result.location, *result.addRegion(), input,
2048 init);
2049}
2050
2051void TransposeOp::build(::mlir::OpBuilder &builder,
2052 ::mlir::OperationState &result, Value input, Value init,
2053 ArrayRef<int64_t> permutation,
2054 ArrayRef<NamedAttribute> attributes) {
2055 build(builder, result, input, init, builder.getDenseI64ArrayAttr(permutation),
2056 attributes);
2057}
2058
2059ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
2061 parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
2062 return parseDenseI64ArrayAttr(parser, attributes, "permutation");
2063 })))
2064 return failure();
2065
2066 OpBuilder builder(parser.getContext());
2067 buildIdentityRegion(builder, result.location, *result.addRegion(),
2068 /*inputs=*/result.operands,
2069 /*outputs=*/{});
2070 return success();
2071}
2072
2073void TransposeOp::getAsmResultNames(
2074 function_ref<void(Value, StringRef)> setNameFn) {
2075 if (!getResults().empty())
2076 setNameFn(getResults().front(), "transposed");
2077}
2078
2079void TransposeOp::print(OpAsmPrinter &p) {
2080 printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
2081 printDenseI64ArrayAttr(p, getPermutationAttrName(), getPermutation());
2082 p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()});
2083}
2084
2085LogicalResult TransposeOp::verify() {
2086 ArrayRef<int64_t> permutationRef = getPermutation();
2087
2088 if (!isPermutationVector(permutationRef))
2089 return emitOpError("permutation is not valid");
2090
2091 auto inputType = getInput().getType();
2092 auto initType = getInit().getType();
2093
2094 int64_t rank = inputType.getRank();
2095
2096 if (failed(verifyRanksMatch(getOperation(), inputType, initType, "input",
2097 "init")))
2098 return failure();
2099
2100 if (rank != static_cast<int64_t>(permutationRef.size()))
2101 return emitOpError() << "size of permutation " << permutationRef.size()
2102 << " does not match the argument rank " << rank;
2103
2104 auto inputDims = inputType.getShape();
2105 auto initDims = initType.getShape();
2106
2107 for (int64_t i = 0; i < rank; ++i) {
2108 int64_t inputDim = inputDims[permutationRef[i]];
2109 int64_t initDim = initDims[i];
2110
2111 if (inputDim != initDim) {
2112 return emitOpError() << "dim(result, " << i << ") = " << initDim
2113 << " doesn't match dim(input, permutation[" << i
2114 << "]) = " << inputDim;
2115 }
2116 }
2117
2118 return success();
2119}
2120
2121SmallVector<utils::IteratorType> TransposeOp::getIteratorTypesArray() {
2122 int64_t rank = getInit().getType().getRank();
2123 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
2124}
2125
2126ArrayAttr TransposeOp::getIndexingMaps() {
2127 Builder builder(getContext());
2128 int64_t rank = getInit().getType().getRank();
2129 return builder.getAffineMapArrayAttr(
2131 llvm::to_vector_of<unsigned>(getPermutation()), getContext())),
2132 builder.getMultiDimIdentityMap(rank)});
2133}
2134
2135void TransposeOp::getEffects(
2136 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2137 &effects) {
2138 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
2139}
2140
2141Speculation::Speculatability TransposeOp::getSpeculatability() {
2142 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
2143}
2144
2145LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
2146 SmallVectorImpl<OpFoldResult> &result) {
2147 // Only the tensor type is supported.
2148 if (!isa<TensorType>(getInput().getType()))
2149 return failure();
2150
2151 // Single dimension transpose.
2152 if (getPermutation().empty()) {
2153 result.push_back(getInput());
2154 return success();
2155 }
2156 // Identity permutation.
2157 if (isIdentityPermutation(getPermutation())) {
2158 result.push_back(getInput());
2159 return success();
2160 }
2161
2162 return failure();
2163}
2164
2165/// Fold transpose with transpose.
2166struct FoldTransposeWithTranspose : OpRewritePattern<linalg::TransposeOp> {
2167 using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
2168
2169 LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
2170 PatternRewriter &rewriter) const override {
2171 auto defTransposeOp = transposeOp.getInput().getDefiningOp<TransposeOp>();
2172 if (!defTransposeOp)
2173 return failure();
2174 ArrayRef<int64_t> defPerms = defTransposeOp.getPermutation();
2175 ArrayRef<int64_t> perms = transposeOp.getPermutation();
2176 SmallVector<int64_t> foldedPerms;
2177 foldedPerms.reserve(perms.size());
2178 for (int64_t perm : perms)
2179 foldedPerms.push_back(defPerms[perm]);
2180
2181 rewriter.replaceOpWithNewOp<TransposeOp>(
2182 transposeOp, defTransposeOp.getInput(), transposeOp.getInit(),
2183 foldedPerms);
2184 return success();
2185 }
2186};
2187
2188/// Rewrite a transpose of a dense splat constant into a dense splat constant of
2189/// the transposed output shape.
2190struct FoldTransposeSplatConstant : OpRewritePattern<linalg::TransposeOp> {
2191 using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
2192
2193 LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
2194 PatternRewriter &rewriter) const override {
2195 if (!transposeOp.hasPureTensorSemantics())
2196 return failure();
2197
2198 auto splatValue =
2199 getScalarConstantAttrFromDenseSplat(transposeOp.getInput());
2200 if (!splatValue.has_value())
2201 return failure();
2202
2203 auto resultType =
2204 cast<RankedTensorType>(transposeOp.getResult()[0].getType());
2205
2206 auto resultAttr = DenseElementsAttr::get(resultType, splatValue.value());
2207 rewriter.replaceOpWithNewOp<arith::ConstantOp>(transposeOp, resultType,
2208 resultAttr);
2209 return success();
2210 }
2211};
2212
2213/// This pattern canonicalize transpose by swapping the order of
2214/// broadcast and transpose:
2215/// transpose(broadcast(input)) -> broadcast(transpose(input))
2216struct SwapTransposeWithBroadcast : OpRewritePattern<linalg::TransposeOp> {
2217 using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
2218
2219 LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
2220 PatternRewriter &rewriter) const override {
2221 Value input = transposeOp.getInput();
2222 BroadcastOp broadcastOp = input.getDefiningOp<BroadcastOp>();
2223 if (!input.hasOneUse() || !broadcastOp)
2224 return failure();
2225
2226 ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
2227 ArrayRef<int64_t> perms = transposeOp.getPermutation();
2228
2229 // Get new perms and new dimensions.
2230 SmallVector<int64_t> resultPerms = dropDims(perms, dimensions);
2232 SmallVector<int64_t> resultDimensions;
2233 unsigned dimensionSize = dimensions.size();
2234 for (unsigned i = 0; i < dimensionSize; ++i)
2235 resultDimensions.push_back(invertPerm[dimensions[i]]);
2236
2237 // Create transpose result.
2238 Value broadcastInput = broadcastOp.getInput();
2239 Location loc = transposeOp.getLoc();
2240 MLIRContext *ctx = transposeOp.getContext();
2242 auto broadcastInputTy =
2243 mlir::cast<RankedTensorType>(broadcastInput.getType());
2244 unsigned inputRank = broadcastInputTy.getRank();
2245 for (unsigned i = 0; i < inputRank; ++i) {
2246 if (broadcastInputTy.isDynamicDim(i)) {
2247 dims.push_back(tensor::DimOp::create(rewriter, loc, broadcastInput, i)
2248 ->getResult(0));
2249 } else {
2250 dims.push_back(IntegerAttr::get(IndexType::get(ctx),
2251 broadcastInputTy.getDimSize(i)));
2252 }
2253 }
2254 SmallVector<OpFoldResult> transposeResultShapes =
2255 applyPermutation(dims, resultPerms);
2256 Value transposeInit = tensor::EmptyOp::create(
2257 rewriter, transposeOp.getLoc(), transposeResultShapes,
2258 broadcastInputTy.getElementType());
2259
2260 // Create broadcast(transpose(input)).
2261 Value transposeResult =
2262 TransposeOp::create(rewriter, loc, broadcastOp.getInput(),
2263 transposeInit, resultPerms)
2264 ->getResult(0);
2265 rewriter.replaceOpWithNewOp<BroadcastOp>(
2266 transposeOp, transposeResult, transposeOp.getInit(), resultDimensions);
2267 return success();
2268 }
2269};
2270
2271void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2272 MLIRContext *context) {
2273 results.add<FoldTransposeWithTranspose, FoldTransposeSplatConstant,
2274 SwapTransposeWithBroadcast>(context);
2275}
2276
2277//===----------------------------------------------------------------------===//
2278// BroadcastOp
2279//===----------------------------------------------------------------------===//
2280
2281void BroadcastOp::build(::mlir::OpBuilder &builder,
2282 ::mlir::OperationState &result, Value input, Value init,
2283 DenseI64ArrayAttr dimensions,
2284 ArrayRef<NamedAttribute> attributes) {
2285 result.addOperands(input);
2286 result.addOperands(init);
2287 result.addAttribute(getDimensionsAttrName(result.name), dimensions);
2288 result.addAttributes(attributes);
2289
2290 // Add output types for `RankedTensorType` output arguments.
2291 Type initType = init.getType();
2292 if (llvm::isa<RankedTensorType>(initType))
2293 result.addTypes(initType);
2294
2295 buildIdentityRegion(builder, result.location, *result.addRegion(), input,
2296 init);
2297}
2298
2299void BroadcastOp::build(::mlir::OpBuilder &builder,
2300 ::mlir::OperationState &result, Value input, Value init,
2301 ArrayRef<int64_t> dimensions,
2302 ArrayRef<NamedAttribute> attributes) {
2303 build(builder, result, input, init, builder.getDenseI64ArrayAttr(dimensions),
2304 attributes);
2305}
2306
2307ParseResult BroadcastOp::parse(OpAsmParser &parser, OperationState &result) {
2309 parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
2310 return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
2311 })))
2312 return failure();
2313
2314 OpBuilder builder(parser.getContext());
2315 buildIdentityRegion(builder, result.location, *result.addRegion(),
2316 /*inputs=*/result.operands,
2317 /*outputs=*/{});
2318 return success();
2319}
2320
2321void BroadcastOp::getAsmResultNames(
2322 function_ref<void(Value, StringRef)> setNameFn) {
2323 if (!getResults().empty())
2324 setNameFn(getResults().front(), "broadcasted");
2325}
2326
2327void BroadcastOp::print(OpAsmPrinter &p) {
2328 printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
2329 printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
2330 p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
2331}
2332
2333LogicalResult BroadcastOp::verify() {
2334 ArrayRef<int64_t> dimensionsRef = getDimensions();
2335
2336 auto inputType = getInput().getType();
2337 auto initType = getInit().getType();
2338
2339 int64_t inputRank = inputType.getRank();
2340 int64_t initRank = initType.getRank();
2341
2342 auto inputShape = inputType.getShape();
2343 auto initShape = initType.getShape();
2344
2345 if ((size_t)inputRank + dimensionsRef.size() != (size_t)initRank)
2346 return emitOpError() << "input rank plus added dimensions does not "
2347 "match init rank. input rank: "
2348 << inputRank
2349 << ", dimensions size: " << dimensionsRef.size()
2350 << ", init rank: " << initRank;
2351
2352 for (const auto &[idx, dim] : llvm::enumerate(dimensionsRef)) {
2353 if (dim < 0 || dim >= initRank)
2354 return emitOpError() << "dimension " << idx
2355 << " is out of range. expected range: [0, "
2356 << initRank - 1 << "], got: " << dim;
2357 }
2358
2359 // Mapping from input dims to init dims.
2360 SmallVector<int64_t> dimMap;
2361 for (auto dim : llvm::seq<int64_t>(0, initRank)) {
2362 if (!llvm::is_contained(dimensionsRef, dim))
2363 dimMap.push_back(dim);
2364 }
2365
2366 for (const auto &[inputDimIdx, initDimIdx] : llvm::enumerate(dimMap)) {
2367 // This dimensions is mapped from the input. Init and input dims should
2368 // match.
2369 if (inputShape[inputDimIdx] != initShape[initDimIdx])
2370 return emitOpError() << "input dim " << inputDimIdx
2371 << " should match init dim " << initDimIdx
2372 << ". input: " << inputShape[inputDimIdx]
2373 << ", init: " << initShape[initDimIdx];
2374 }
2375
2376 return success();
2377}
2378
2379SmallVector<utils::IteratorType> BroadcastOp::getIteratorTypesArray() {
2380 int64_t rank = getInit().getType().getRank();
2381 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
2382}
2383
2384ArrayAttr BroadcastOp::getIndexingMaps() {
2385 Builder builder(getContext());
2386 int64_t rank = getInit().getType().getRank();
2387 return builder.getAffineMapArrayAttr(
2388 {builder.getMultiDimIdentityMap(rank).dropResults(getDimensions()),
2389 builder.getMultiDimIdentityMap(rank)});
2390}
2391
2392void BroadcastOp::getEffects(
2393 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2394 &effects) {
2395 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
2396}
2397
2398Speculation::Speculatability BroadcastOp::getSpeculatability() {
2399 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
2400}
2401
2402/// Fold back-to-back broadcasts together.
2403struct FoldBroadcasts : OpRewritePattern<linalg::BroadcastOp> {
2404 using OpRewritePattern<linalg::BroadcastOp>::OpRewritePattern;
2405
2406 LogicalResult matchAndRewrite(linalg::BroadcastOp broadcastOp,
2407 PatternRewriter &rewriter) const override {
2408 auto defBroadcastOp = broadcastOp.getInput().getDefiningOp<BroadcastOp>();
2409 if (!defBroadcastOp)
2410 return failure();
2411 ArrayRef<int64_t> defDimensions = defBroadcastOp.getDimensions();
2412 ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
2413 SmallVector<int64_t> foldedDims(dimensions);
2414 Value init = broadcastOp.getInit();
2415 int64_t initRank = cast<ShapedType>(init.getType()).getRank();
2416 // Mapping from input dims to init dims.
2417 SmallVector<int64_t> dimMap;
2418 for (auto dim : llvm::seq<int64_t>(0, initRank)) {
2419 if (!llvm::is_contained(dimensions, dim))
2420 dimMap.push_back(dim);
2421 }
2422 for (auto dim : defDimensions)
2423 foldedDims.push_back(dimMap[dim]);
2424
2425 llvm::sort(foldedDims);
2426 rewriter.replaceOpWithNewOp<BroadcastOp>(
2427 broadcastOp, defBroadcastOp.getInput(), init, foldedDims);
2428 return success();
2429 }
2430};
2431
2432/// Rewrite a broadcast of a dense splat constant into a dense splat constant of
2433/// the broadcast output shape.
2434struct FoldBroadcastSplatConstant : OpRewritePattern<linalg::BroadcastOp> {
2435 using OpRewritePattern<linalg::BroadcastOp>::OpRewritePattern;
2436
2437 LogicalResult matchAndRewrite(linalg::BroadcastOp broadcastOp,
2438 PatternRewriter &rewriter) const override {
2439 if (!broadcastOp.hasPureTensorSemantics())
2440 return failure();
2441
2442 auto splatValue =
2443 getScalarConstantAttrFromDenseSplat(broadcastOp.getInput());
2444
2445 if (!splatValue.has_value())
2446 return failure();
2447
2448 auto resultType =
2449 cast<RankedTensorType>(broadcastOp.getResult()[0].getType());
2450 if (!resultType.hasStaticShape())
2451 return rewriter.notifyMatchFailure(broadcastOp,
2452 "result type has dynamic shape");
2453
2454 auto resultAttr = DenseElementsAttr::get(resultType, splatValue.value());
2455 rewriter.replaceOpWithNewOp<arith::ConstantOp>(broadcastOp, resultType,
2456 resultAttr);
2457 return success();
2458 }
2459};
2460
2461void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2462 MLIRContext *context) {
2463 results.add<EraseIdentityLinalgOp<BroadcastOp>, FoldBroadcasts,
2464 FoldBroadcastSplatConstant>(context);
2465}
2466
2467//===----------------------------------------------------------------------===//
2468// YieldOp
2469//===----------------------------------------------------------------------===//
2470
2471void linalg::YieldOp::print(OpAsmPrinter &p) {
2472 if (getNumOperands() > 0)
2473 p << ' ' << getOperands();
2474 p.printOptionalAttrDict((*this)->getAttrs());
2475 if (getNumOperands() > 0)
2476 p << " : " << getOperandTypes();
2477}
2478
2479ParseResult YieldOp::parse(OpAsmParser &parser, OperationState &result) {
2480 SmallVector<OpAsmParser::UnresolvedOperand, 2> opInfo;
2481 SmallVector<Type, 2> types;
2482 SMLoc loc = parser.getCurrentLocation();
2483 return failure(parser.parseOperandList(opInfo) ||
2484 parser.parseOptionalAttrDict(result.attributes) ||
2485 (!opInfo.empty() && parser.parseColonTypeList(types)) ||
2486 parser.resolveOperands(opInfo, types, loc, result.operands));
2487}
2488
2489// Check the operand number and types must match the element types of the
2490// LinalgOp interface's shaped operands.
2491static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
2492 if (op.getNumOperands() != linalgOp.getNumDpsInits())
2493 return op.emitOpError("expected number of yield values (")
2494 << op.getNumOperands()
2495 << ") to match the number of inits / outs operands of the enclosing "
2496 << "LinalgOp (" << linalgOp.getNumDpsInits() << ")";
2497
2498 for (OpOperand &opOperand : op->getOpOperands()) {
2499 OpOperand *outputOperand =
2500 linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
2501 Type elementType = outputOperand->get().getType();
2502 if (isa<MemRefType, RankedTensorType>(elementType))
2503 elementType = getElementTypeOrSelf(outputOperand->get().getType());
2504 if (opOperand.get().getType() != elementType)
2505 return op.emitOpError("type of yield operand ")
2506 << (opOperand.getOperandNumber() + 1) << " ("
2507 << opOperand.get().getType() << ") doesn't match "
2508 << "the element type of the enclosing linalg.generic op ("
2509 << elementType << ")";
2510 }
2511 return success();
2512}
2513
2514LogicalResult linalg::YieldOp::verify() {
2515 auto *parentOp = (*this)->getParentOp();
2516 if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
2517 return emitOpError("expected single non-empty parent region");
2518
2519 if (auto linalgOp = dyn_cast<LinalgOp>(parentOp))
2520 return verifyYield(*this, linalgOp);
2521
2522 return emitOpError("expected parent op with LinalgOp interface");
2523}
2524
2525//===----------------------------------------------------------------------===//
2526// IndexOp
2527//===----------------------------------------------------------------------===//
2528
2529LogicalResult IndexOp::verify() {
2530 auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp());
2531 if (!linalgOp)
2532 return emitOpError("expected parent op with LinalgOp interface");
2533 if (linalgOp.getNumLoops() <= getDim())
2534 return emitOpError("expected dim (")
2535 << getDim() << ") to be lower than the number of loops ("
2536 << linalgOp.getNumLoops() << ") of the enclosing LinalgOp";
2537 return success();
2538}
2539
2540OpFoldResult IndexOp::fold(FoldAdaptor adaptor) {
2541 auto linalgOp = dyn_cast_or_null<LinalgOp>((*this)->getParentOp());
2542 // Bail out if `linalg.index` does not have a proper parent yet at this
2543 // point, e.g., when calling `createOrFold` during IR construction in
2544 // `genericOp::build`.
2545 if (!linalgOp)
2546 return OpFoldResult{};
2547
2548 // Index of unit dims is always 0.
2549 SmallVector<int64_t, 4> loopBounds = linalgOp.getStaticLoopRanges();
2550 uint64_t dim = getDim();
2551 assert(dim < loopBounds.size() && "Dim is out of bounds");
2552 if (loopBounds[dim] == 1)
2553 return IntegerAttr::get(IndexType::get(getContext()), 0);
2554
2555 return OpFoldResult{};
2556}
2557
2558/////// Operations corresponding to library calls defined with Tablegen ////////
2559
2560#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
2561
2562#define GET_OP_CLASSES
2563#include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
2564
2565#define GET_OP_CLASSES
2566#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
2567#define GET_OP_CLASSES
2568#include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.cpp.inc"
2569
2570AffineMap mlir::linalg::extractOrIdentityMap(std::optional<AffineMap> maybeMap,
2571 unsigned rank,
2572 MLIRContext *context) {
2573 if (maybeMap)
2574 return *maybeMap;
2575 if (rank == 0)
2576 return AffineMap::get(context);
2577 return AffineMap::getMultiDimIdentityMap(rank, context);
2578}
2579
2581mlir::linalg::makeAffineDimExprs(unsigned num, unsigned &startIdx,
2582 MLIRContext *context) {
2584 res.reserve(num);
2585 for (unsigned i = 0; i < num; ++i)
2586 res.push_back(getAffineDimExpr(startIdx++, context));
2587 return res;
2588}
2589
2592 auto rangeA = llvm::make_range(a.begin(), a.end());
2593 auto rangeB = llvm::make_range(b.begin(), b.end());
2594 auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
2595 return llvm::to_vector<4>(concatRanges);
2596}
2597
2598static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t) {
2599 if (auto memref = llvm::dyn_cast<MemRefType>(t)) {
2600 ss << "view";
2601 for (auto size : memref.getShape())
2602 if (size < 0)
2603 ss << "sx";
2604 else
2605 ss << size << "x";
2606 if (failed(appendMangledType(ss, memref.getElementType())))
2607 return failure();
2608 if (auto as = memref.getMemorySpace()) {
2609 if (auto attr = llvm::dyn_cast<IntegerAttr>(as))
2610 ss << "as" << attr.getInt();
2611 else
2612 return failure();
2613 }
2614 return success();
2615 }
2616 if (auto vec = llvm::dyn_cast<VectorType>(t)) {
2617 ss << "vector";
2618 llvm::interleave(
2619 vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
2620 if (failed(appendMangledType(ss, vec.getElementType())))
2621 return failure();
2622 return success();
2623 }
2625 ss << t;
2626 return success();
2627 }
2628 return failure();
2629}
2630
2632 assert(isa<LinalgOp>(op));
2633 std::string name(op->getName().getStringRef().str());
2634 std::string fun = "";
2635 for (NamedAttribute kv : op->getAttrs()) {
2636 if (UnaryFnAttr ufa = llvm::dyn_cast<UnaryFnAttr>(kv.getValue())) {
2637 fun = stringifyEnum(ufa.getValue()).str() + "_";
2638 } else if (BinaryFnAttr bfa = llvm::dyn_cast<BinaryFnAttr>(kv.getValue())) {
2639 fun = stringifyEnum(bfa.getValue()).str() + "_";
2640 }
2641 }
2642 name.reserve(128);
2643 llvm::replace(name, '.', '_');
2644 llvm::raw_string_ostream ss(name);
2645 ss << "_" << fun;
2646 for (Type t : op->getOperandTypes()) {
2647 if (failed(appendMangledType(ss, t)))
2648 return std::string();
2649 ss << "_";
2650 }
2651 name.pop_back();
2652 return name;
2653}
2654
2655//===----------------------------------------------------------------------===//
2656// Canonicalizers and Folders.
2657//===----------------------------------------------------------------------===//
2658
2659namespace {
2660struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> {
2662
2663 LogicalResult matchAndRewrite(LinalgOp op,
2664 PatternRewriter &rewriter) const override {
2665 for (OpOperand &opOperand : op->getOpOperands()) {
2666 // Linalg "inputs" may be either tensor or memref type.
2667 // tensor<0xelt_type> is a convention that may not always mean
2668 // "0 iterations". Only erase in cases we see memref<...x0x...>.
2669 auto mt = llvm::dyn_cast<MemRefType>(opOperand.get().getType());
2670 if (!mt)
2671 continue;
2672 if (llvm::is_contained(op.getShape(&opOperand), 0)) {
2673 rewriter.eraseOp(op);
2674 return success();
2675 }
2676 }
2677 return failure();
2678 }
2679};
2680
2681/// Fold LinalgOps with `tensor.cast` consumer if the `tensor.cast` has
2682/// result that is more static than the linalg op.
2683struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
2684 using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
2685
2686 LogicalResult matchAndRewrite(tensor::CastOp castOp,
2687 PatternRewriter &rewriter) const override {
2688 if (!tensor::canFoldIntoProducerOp(castOp))
2689 return failure();
2690
2691 auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
2692 if (!linalgOp)
2693 return failure();
2694
2695 // Cast can be in conditionally reachable region, if which case folding will
2696 // generate invalid code. Only conservatively fold ops in same block for
2697 // now.
2698 if (castOp->getBlock() != linalgOp->getBlock())
2699 return failure();
2700
2701 OpBuilder::InsertionGuard guard(rewriter);
2702 rewriter.setInsertionPoint(linalgOp);
2703
2704 Location loc = linalgOp.getLoc();
2705 OpResult resultValue = llvm::cast<OpResult>(castOp.getSource());
2706 unsigned resultNumber = resultValue.getResultNumber();
2707 auto resultType =
2708 llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
2709 // Replace the `outs` for the result with a `tensor.cast`. This cast is now
2710 // going from a more dynamic shape to a less dynamic shape. If the producer
2711 // for this cast, i.e. producer of the out operand, is also an operation
2712 // that folds with tensor.cast consumer (like this pattern), the cast will
2713 // continue to propagate as far up the stack as it can go.
2714 OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
2715 Value newOperand =
2716 tensor::CastOp::create(rewriter, loc, resultType, outOperand->get());
2717 SmallVector<Value> newOperands = linalgOp.getDpsInputs();
2718 SmallVector<Value> outputOperands(linalgOp.getDpsInits().begin(),
2719 linalgOp.getDpsInits().end());
2720 outputOperands[resultNumber] = newOperand;
2721 newOperands.append(outputOperands.begin(), outputOperands.end());
2722
2723 SmallVector<Type> resultTypes(linalgOp->result_type_begin(),
2724 linalgOp->result_type_end());
2725 resultTypes[resultNumber] = resultType;
2726 Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
2727
2728 // Create a tensor.cast operation back to the original type.
2729 Value castBack = tensor::CastOp::create(
2730 rewriter, loc, resultValue.getType(), newOp->getResult(resultNumber));
2731
2732 SmallVector<Value> results(newOp->result_begin(), newOp->result_end());
2733 results[resultNumber] = castBack;
2734 rewriter.replaceOp(linalgOp, results);
2735 rewriter.replaceOp(castOp, newOp->getResult(resultNumber));
2736 return success();
2737 }
2738};
2739
2740/// For each of the operand in `operands` this function maps the static sizes of
2741/// dimensions to their affine dim expressions.
2742static void populateMap(LinalgOp linalgOp, MutableArrayRef<OpOperand> operands,
2743 llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
2744 for (OpOperand &opOperand : operands) {
2745 if (linalgOp.isScalar(&opOperand))
2746 continue;
2747 Value src = opOperand.get();
2748 auto sourceType = llvm::cast<RankedTensorType>(src.getType());
2749 auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);
2750
2751 // Get the `sourceShape` of the `sourceType`. If the operand is a result of
2752 // `tensor.cast` operation and source of the cast operation has a static
2753 // shape, then assign it to the `sourceShape`.
2754 auto *parentOp = src.getDefiningOp();
2755 ArrayRef<int64_t> sourceShape = sourceType.getShape();
2756 if (parentOp) {
2757 if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
2758 Value castSource = castOp.getSource();
2759 auto castSourceType =
2760 llvm::dyn_cast<RankedTensorType>(castSource.getType());
2761 if (castSourceType && castSourceType.hasStaticShape())
2762 sourceShape = castSourceType.getShape();
2763 }
2764 }
2765
2766 // If the source shape's dimension has a static shape, map the affine dim
2767 // expression to the known static size.
2768 for (unsigned i = 0; i < sourceShape.size(); i++) {
2769 if (sourceType.isDynamicDim(i))
2770 continue;
2771 if (auto affineDimExpr = dyn_cast<AffineDimExpr>(sourceMap.getResult(i)))
2772 affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
2773 }
2774 }
2775}
2776
2777/// Creates new operand w.r.t 'opOperand' of `linalgOp` with static sizes
2778/// mapped in `affineExprToSize`. New operands are created in `newOperands` and
2779/// their result types is stored in `resultTypes`. If `opOperand` requires no
2780/// change then `changeNeeded` is false and same operand is added in the
2781/// `newOperands` list.
2782static void createNewOperandWithStaticSizes(
2783 Location loc, PatternRewriter &rewriter, OpOperand *opOperand,
2784 llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize, LinalgOp linalgOp,
2785 SmallVector<Value> &newOperands, SmallVector<Type> &resultTypes,
2786 bool &changeNeeded) {
2787 Value src = opOperand->get();
2788 newOperands.push_back(src);
2789 if (linalgOp.isScalar(opOperand))
2790 return;
2791 auto sourceType = llvm::cast<RankedTensorType>(src.getType());
2792 Type resultType = sourceType;
2793 if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
2794 resultTypes.push_back(resultType);
2795 return;
2796 }
2797 ArrayRef<int64_t> sourceShape = sourceType.getShape();
2798 AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
2799 SmallVector<int64_t> newShape;
2800 // If operand is updated with new shape, `newOperandNeeded` will be
2801 // true.
2802 bool newOperandNeeded = false;
2803 for (unsigned i = 0; i < sourceShape.size(); i++) {
2804 int64_t dimShape = sourceShape[i];
2805 AffineExpr dimExpr = sourceMap.getResult(i);
2806 if (!affineExprToSize.contains(dimExpr) || !sourceType.isDynamicDim(i)) {
2807 newShape.push_back(dimShape);
2808 continue;
2809 }
2810 // Dimension has a dynamic shape and corresponding affine dim
2811 // expression is present in the map. So assign the size for the
2812 // given affine dim expression to the dimension.
2813 newShape.push_back(affineExprToSize[dimExpr]);
2814 newOperandNeeded = true;
2815 }
2816 resultType = RankedTensorType::get(newShape, sourceType.getElementType(),
2817 sourceType.getEncoding());
2818 if (newOperandNeeded) {
2819 changeNeeded = true;
2820 // Get the new operand value given its size and element type by
2821 // casting it.
2822 Value newOperand = tensor::CastOp::create(rewriter, loc, resultType, src);
2823 unsigned index = opOperand->getOperandNumber();
2824 newOperands[index] = newOperand;
2825 }
2826 if (linalgOp.isDpsInit(opOperand))
2827 resultTypes.push_back(resultType);
2828}
2829
2830/// Static shapes for the operands can be inferred if any one of the operands
2831/// have a static shape. This can be done by referring to the affine dim
2832/// expressions for the operand.
2833struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> {
2834 using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
2835
2836 LogicalResult matchAndRewrite(LinalgOp linalgOp,
2837 PatternRewriter &rewriter) const override {
2838 if (!linalgOp.hasPureTensorSemantics())
2839 return failure();
2840
2841 // Maps must be projected permutations.
2842 if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](AffineMap map) {
2843 return !map.isProjectedPermutation();
2844 }))
2845 return failure();
2846
2847 // Maps affine dim expressions to the static size of that dimension.
2848 llvm::DenseMap<AffineExpr, int64_t> affineExprToSize;
2849 Location loc = linalgOp.getLoc();
2850
2851 // For each of the affine dim expression, check if the size is known. If
2852 // known add that in the map.
2853 populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);
2854
2855 SmallVector<Value> newOperands;
2856 SmallVector<Type> resultTypes;
2857
2858 // `changeNeeded` is `false` if the operands of `linalgOp` require no
2859 // change in their types.
2860 bool changeNeeded = false;
2861 newOperands.reserve(linalgOp->getNumOperands());
2862 resultTypes.reserve(linalgOp.getNumDpsInits());
2863
2864 // Iterate over all the operands and update the static sizes.
2865 for (OpOperand &opOperand : linalgOp->getOpOperands()) {
2866 createNewOperandWithStaticSizes(loc, rewriter, &opOperand,
2867 affineExprToSize, linalgOp, newOperands,
2868 resultTypes, changeNeeded);
2869 }
2870
2871 // If the generic op has all the required static information, no
2872 // canonicalization needed.
2873 if (!changeNeeded)
2874 return failure();
2875
2876 // Clone op.
2877 Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
2878 SmallVector<Value> replacements;
2879 replacements.reserve(newOp->getNumResults());
2880 for (auto it : llvm::zip(linalgOp->getResults(), newOp->getResults())) {
2881 Value newResult = std::get<1>(it);
2882 Value oldResult = std::get<0>(it);
2883 Type newType = newResult.getType();
2884 Type oldType = oldResult.getType();
2885 replacements.push_back(
2886 (newType != oldType)
2887 ? tensor::CastOp::create(rewriter, loc, oldType, newResult)
2888 : newResult);
2889 }
2890 rewriter.replaceOp(linalgOp, replacements);
2891 return success();
2892 }
2893};
2894
2895} // namespace
2896
2897// All named ops canonicalizers and folders are auto-generated in the
2898// .cpp.inc.
2899
2900//===----------------------------------------------------------------------===//
2901// SoftmaxOp
2902//===----------------------------------------------------------------------===//
2903
2904LogicalResult SoftmaxOp::verify() {
2905 ShapedType inputType = getInputOperandType();
2906 ShapedType outputType = getOutputOperandType();
2907
2908 ArrayRef<int64_t> inputShape = inputType.getShape();
2909 ArrayRef<int64_t> outputShape = outputType.getShape();
2910 if (failed(verifyCompatibleShape(inputShape, outputShape)))
2911 return emitOpError("incompatible output shape");
2912
2913 int64_t inputRank = getInputOperandRank();
2914 int64_t dimension = getDimension();
2915 if ((dimension < 0) || (dimension >= inputRank))
2916 return emitOpError("incorrect dimension specified");
2917
2918 return success();
2919}
2920
2921SmallVector<Range> SoftmaxOp::getIterationDomain(OpBuilder &builder) {
2922 int64_t operandRank = getInputOperandRank();
2923 SmallVector<Range> loopBounds(operandRank);
2924 Location loc = getLoc();
2925 Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
2926 Value one = arith::ConstantIndexOp::create(builder, loc, 1);
2927 Value source = getInput();
2928 for (auto dim : llvm::seq<int64_t>(0, operandRank)) {
2929 loopBounds[dim].offset = zero;
2930 loopBounds[dim].size = getDimValue(builder, loc, source, dim);
2931 loopBounds[dim].stride = one;
2932 }
2933 return loopBounds;
2934}
2935
2936SmallVector<utils::IteratorType> SoftmaxOp::getLoopIteratorTypes() {
2937 SmallVector<utils::IteratorType> iteratorTypes(getInputOperandRank(),
2938 utils::IteratorType::parallel);
2939 iteratorTypes[getDimension()] = utils::IteratorType::reduction;
2940 return iteratorTypes;
2941}
2942
2943FailureOr<TilingResult>
2944SoftmaxOp::getTiledImplementation(OpBuilder &builder,
2945 ArrayRef<OpFoldResult> offsets,
2946 ArrayRef<OpFoldResult> sizes) {
2947 int64_t rank = getInputOperandRank();
2948 auto oneAttr = builder.getI64IntegerAttr(1);
2949 SmallVector<OpFoldResult> strides(rank, oneAttr);
2950 SmallVector<Value> tiledOperands;
2951 Operation *inputSlice =
2952 getSlice(builder, getLoc(), getInput(), offsets, sizes, strides);
2953 if (!inputSlice) {
2954 return emitOpError("failed to compute input slice");
2955 }
2956 tiledOperands.emplace_back(inputSlice->getResult(0));
2957 Operation *outputSlice =
2958 getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides);
2959 if (!outputSlice) {
2960 return emitOpError("failed to compute output slice");
2961 }
2962 tiledOperands.emplace_back(outputSlice->getResult(0));
2963
2964 SmallVector<Type, 4> resultTypes;
2965 if (hasPureTensorSemantics())
2966 resultTypes.push_back(tiledOperands[1].getType());
2967 Operation *tiledOp =
2968 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
2969
2970 return TilingResult{
2971 {tiledOp},
2972 SmallVector<Value>(tiledOp->getResults()),
2973 llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
2974}
2975
2976LogicalResult SoftmaxOp::getResultTilePosition(
2977 OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
2978 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
2979 SmallVector<OpFoldResult> &resultSizes) {
2980 if (resultNumber == 0) {
2981 resultOffsets.assign(offsets.begin(), offsets.end());
2982 resultSizes.assign(sizes.begin(), sizes.end());
2983 return success();
2984 }
2985 return failure();
2986}
2987
2988// cast(dynamic) -> static.
2989LogicalResult SoftmaxOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
2990 return memref::foldMemRefCast(*this);
2991}
2992
2993LogicalResult
2994SoftmaxOp::reifyResultShapes(OpBuilder &b,
2995 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2996 SmallVector<OpFoldResult> shapes;
2997 Location loc = getOperation()->getLoc();
2998 IRRewriter rewriter(b);
2999 auto inputShapedType = llvm::cast<ShapedType>(getInputOperandType());
3000 auto outputShapedType = llvm::cast<ShapedType>(getOutputOperandType());
3001 for (int64_t dim : llvm::seq<int64_t>(0, getOutputOperandRank())) {
3002 if (!outputShapedType.isDynamicDim(dim)) {
3003 // Static dim: Return IntegerAttr.
3004 shapes.push_back(b.getIndexAttr(inputShapedType.getDimSize(dim)));
3005 } else {
3006 // Dynamic dim: Return Value.
3007 OpFoldResult ofr = createOrFoldDimOp(b, loc, getInput(), dim);
3008 shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr));
3009 }
3010 }
3011 reifiedReturnShapes.emplace_back(std::move(shapes));
3012 return success();
3013}
3014
3015void SoftmaxOp::getEffects(
3016 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
3017 &effects) {
3018 for (auto [index, operand] : llvm::enumerate(getDpsInputs())) {
3019 if (!llvm::isa<MemRefType>(operand.getType()))
3020 continue;
3021 effects.emplace_back(MemoryEffects::Read::get(),
3022 &getOperation()->getOpOperand(index), /*stage=*/0,
3023 /*effectOnFullRegion=*/true,
3025 }
3026
3027 for (OpOperand &operand : getDpsInitsMutable()) {
3028 if (!llvm::isa<MemRefType>(operand.get().getType()))
3029 continue;
3030 effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
3031 /*effectOnFullRegion=*/true,
3033 effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
3034 /*effectOnFullRegion=*/true,
3036 }
3037}
3038
3039// Helper functions for softmax decomposition.
3040// @{
3041
3042// Helper function to produce the iterator types (reduction or parallel) and
3043// affine maps for the iterators used in the decomposition of softmax.
3044// This method creates:
3045// If allParallel == true:
3046// - iterator type: {parallel, ..., parallel}
3047// - affine maps:
3048// -- identity with inputRank dimensions.
3049// -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
3050// where N == inputRank.
3051//
3052// If allParallel == false:
3053// - iterator type at dim(i) == parallel for i != \p dim and
3054// dim(dim) == reduction.
3055// - affine map:
3056// -- identity with inputRank dimensions.
3057// -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
3058// where N == inputRank.
3059static std::tuple<SmallVector<utils::IteratorType>, SmallVector<AffineMap>>
3061 int64_t dim, bool allParallel = false) {
3062 SmallVector<utils::IteratorType> iteratorTypes(inputRank,
3063 utils::IteratorType::parallel);
3064 if (!allParallel)
3065 iteratorTypes[dim] = utils::IteratorType::reduction;
3066 MLIRContext *ctxt = builder.getContext();
3067 auto identityMap = AffineMap::getMultiDimIdentityMap(inputRank, ctxt);
3068 SmallVector<AffineExpr, 2> affineExprs;
3069 for (int i = 0; i < inputRank; i++) {
3070 if (i != dim)
3071 affineExprs.push_back(mlir::getAffineDimExpr(i, ctxt));
3072 }
3073 auto reductionMap =
3074 AffineMap::get(inputRank, /*symbols=*/0, affineExprs, ctxt);
3075 SmallVector<AffineMap> indexingMaps{identityMap, reductionMap};
3076 return std::make_tuple(iteratorTypes, indexingMaps);
3077}
3078
3079// Helper function to produce a linalg.generic that computes a reduction on
3080// dimension \p dim with the operation type \p T.
3081template <typename T>
3082static Value reduce(OpBuilder &builder, Location loc, Value input, Value output,
3083 int64_t dim) {
3084 auto inputType = cast<ShapedType>(input.getType());
3085 ArrayRef<int64_t> inputShape = inputType.getShape();
3086 int64_t inputRank = inputShape.size();
3087 auto [iteratorTypes, indexingMaps] =
3088 computeIteratorTypesAndIndexingMaps(builder, inputRank, dim);
3089 assert(indexingMaps.size() == 2 &&
3090 "We should have two maps: 1 for the input, 1 for the output");
3091 assert(indexingMaps[0].isIdentity() && "input map should be identity");
3092
3093 auto genericOp = linalg::GenericOp::create(
3094 builder, loc, output.getType(), input, output, indexingMaps,
3095 iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
3096 Value result = T::create(b, loc, args[0], args[1]);
3097 linalg::YieldOp::create(b, loc, result);
3098 });
3099 return genericOp.getResult(0);
3100}
3101
3102/// Produce a linalg generic that computes the second step of the softmax
3103/// decomposition: res = exp(input - max), where \p max is the max of \p input
3104/// on dimension \p dim.
3105static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input,
3106 Value max, Value output, int64_t dim) {
3107 auto inputType = cast<ShapedType>(input.getType());
3108 ArrayRef<int64_t> inputShape = inputType.getShape();
3109 int64_t inputRank = inputShape.size();
3110 auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
3111 builder, inputRank, dim, /*allParallel=*/true);
3112 assert(indexingMaps.size() == 2 && "We should have one map for each input");
3113 assert(indexingMaps[0].isIdentity() && "input map should be identity");
3114 // Add the affine map for the output argument.
3115 indexingMaps.push_back(indexingMaps[0]);
3116 auto genericOp = linalg::GenericOp::create(
3117 builder, loc, input.getType(), ValueRange{input, max}, output,
3118 indexingMaps, iteratorTypes,
3119 [&](OpBuilder &b, Location loc, ValueRange args) {
3120 Value diff = arith::SubFOp::create(b, loc, args[0], args[1]);
3121 Value result = math::ExpOp::create(b, loc, diff);
3122 linalg::YieldOp::create(b, loc, result);
3123 });
3124 return genericOp.getResult(0);
3125}
3126
3127/// Produce a linalg generic that computes the final step of the softmax
3128/// decomposition.
3129/// \returns linalg.generic ins(\p numerator, \p denominator) outs(\p output) {
3130/// yield n / d
3131/// }
3132static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator,
3133 Value denominator, Value output, int64_t dim) {
3134 auto inputType = cast<ShapedType>(numerator.getType());
3135 ArrayRef<int64_t> inputShape = inputType.getShape();
3136 int64_t inputRank = inputShape.size();
3137 auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
3138 builder, inputRank, dim, /*allParallel=*/true);
3139 assert(indexingMaps.size() == 2 &&
3140 "We should have one map for each input (2)");
3141 assert(indexingMaps[0].isIdentity() && "Numerator map should be identity");
3142 // Add the affine map for the output tensor.
3143 indexingMaps.push_back(indexingMaps[0]);
3144 auto genericOp = linalg::GenericOp::create(
3145 builder, loc, numerator.getType(), ValueRange{numerator, denominator},
3146 output, indexingMaps, iteratorTypes,
3147 [&](OpBuilder &b, Location loc, ValueRange args) {
3148 Value result = arith::DivFOp::create(b, loc, args[0], args[1]);
3149 linalg::YieldOp::create(b, loc, result);
3150 });
3151 return genericOp.getResult(0);
3152}
3153// @} End helper functions for softmax decomposition.
3154
3155/// Given an N-dimensional tensor x, this method converts
3156/// softmax(x) to the following sequence of operations:
3157///
3158/// 1. Compute the max of x along dimension d. This results
3159/// in a N-1 dimensional tensor m.
3160/// m = max(x, dim = d)
3161///
3162/// 2. Subtract a broadcasted m from x and exponentiate. This results in
3163/// a N dimensional tensor z.
3164/// z = exp(x - m)
3165///
3166/// 3. Compute the sum of z along dimension d. This results in
3167/// a N-1 dimensional tensor l.
3168/// l = sum(z, dim = d)
3169///
3170/// 4. Divide z and l. This gives the N-dimensional softmax.
3171/// softmax = z / l
3172///
3173FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
3174 OpBuilder::InsertionGuard guard(b);
3175 b.setInsertionPoint(*this);
3176 Location loc = getLoc();
3177 Value input = getInput();
3178 ShapedType inputType = getInputOperandType();
3179 Type elementType = inputType.getElementType();
3180 int64_t reductionDim = getDimension();
3181 SmallVector<OpFoldResult> dims = tensor::getMixedSizes(b, loc, input);
3182 Value output = getOutput();
3183 dims.erase(dims.begin() + reductionDim);
3184 // Step 1: Compute max along dim.
3185 Value outputReduce = tensor::EmptyOp::create(b, loc, dims, elementType);
3186 Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maxnumf,
3187 elementType, b, loc,
3188 /*useOnlyFiniteValue=*/true);
3189 Value neutralForMaxFInit =
3190 linalg::FillOp::create(b, loc, Value{neutralForMaxF}, outputReduce)
3191 .result();
3192 Value max =
3193 reduce<arith::MaxNumFOp>(b, loc, input, neutralForMaxFInit, reductionDim);
3194
3195 // Step 2: Subtract max from input and exponentiate.
3196 Value numerator = buildSubAndExpOp(b, loc, input, max, output, reductionDim);
3197
3198 // Step 3: Compute sum along dim.
3199 Value zero = arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType,
3200 b, loc, /*useOnlyFiniteValue=*/true);
3201 Value zeroInit =
3202 linalg::FillOp::create(b, loc, Value{zero}, outputReduce).result();
3203 Value denominator =
3204 reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);
3205
3206 // Step 4: Compute softmax.
3207 Value result =
3208 buildDivOp(b, loc, numerator, denominator, output, reductionDim);
3209 return SmallVector<Value>{result};
3210}
3211
3212//===----------------------------------------------------------------------===//
3213// WinogradFilterTransformOp
3214//===----------------------------------------------------------------------===//
3215
3216LogicalResult WinogradFilterTransformOp::verify() {
3217 auto filterType = cast<ShapedType>(getFilter().getType());
3218 ArrayRef<int64_t> filterShape = filterType.getShape();
3219 int64_t filterH = filterShape[getFilterHDim()];
3220 int64_t filterW = filterShape[getFilterWDim()];
3221 WinogradConv2DFmr fmr = getFmr();
3222 int64_t m, r;
3223 std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3224
3225 if (filterH != r && filterH != 1)
3226 return emitOpError("expect filter height either equals to r or 1");
3227 if (filterW != r && filterW != 1)
3228 return emitOpError("expect filter width either equals to r or 1");
3229 if (filterH == 1 && filterW == 1)
3230 return emitOpError("expect either filter height or width equals to r");
3231
3232 SmallVector<int64_t> expectedOutputShape;
3233 expectedOutputShape.push_back(filterH == r ? m + r - 1 : 1);
3234 expectedOutputShape.push_back(filterW == r ? m + r - 1 : 1);
3235 expectedOutputShape.push_back(filterShape[getFilterCDim()]);
3236 expectedOutputShape.push_back(filterShape[getFilterFDim()]);
3237
3238 auto outputType = cast<ShapedType>(getOutput().getType());
3239 ArrayRef<int64_t> outputShape = outputType.getShape();
3240 if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
3241 return emitOpError("the output shape is not expected");
3242 }
3243 return success();
3244}
3245
3246SmallVector<Range>
3247WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) {
3248 Location loc = getLoc();
3249 IntegerAttr zeroAttr = builder.getIndexAttr(0);
3250 IntegerAttr oneAttr = builder.getIndexAttr(1);
3251 Value filter = getFilter();
3252 int64_t filterRank = getFilterOperandRank();
3253 SmallVector<Range> loopBounds(filterRank);
3254 for (unsigned dim = 0; dim < filterRank; ++dim) {
3255 loopBounds[dim].offset = zeroAttr;
3256 loopBounds[dim].size = getDimValue(builder, loc, filter, dim);
3257 loopBounds[dim].stride = oneAttr;
3258 }
3259 return loopBounds;
3260}
3261
3262SmallVector<utils::IteratorType>
3263WinogradFilterTransformOp::getLoopIteratorTypes() {
3264 int64_t filterRank = getFilterOperandRank();
3265 SmallVector<utils::IteratorType> iteratorTypes(filterRank,
3266 utils::IteratorType::parallel);
3267 return iteratorTypes;
3268}
3269
3270LogicalResult WinogradFilterTransformOp::getResultTilePosition(
3271 OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3272 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3273 SmallVector<OpFoldResult> &resultSizes) {
3274 IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3275 ShapedType filterType = getFilterOperandType();
3276 ArrayRef<int64_t> filterShape = filterType.getShape();
3277 int64_t filterH = filterShape[getFilterHDim()];
3278 int64_t filterW = filterShape[getFilterWDim()];
3279 WinogradConv2DFmr fmr = getFmr();
3280 int64_t m, r;
3281 std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3282 int64_t alpha = m + r - 1;
3283 int64_t alphaH = filterH != 1 ? alpha : 1;
3284 int64_t alphaW = filterW != 1 ? alpha : 1;
3285 IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
3286 IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
3287
3288 resultOffsets.append(
3289 {zeroAttr, zeroAttr, offsets[getFilterCDim()], offsets[getFilterFDim()]});
3290 resultSizes.append(
3291 {alphaHAttr, alphaWAttr, sizes[getFilterCDim()], sizes[getFilterFDim()]});
3292
3293 return success();
3294}
3295
3296/// Implement tiling for winograd_filter_transform
3297/// The input of winograd_filter_transform is (F, KH, KW, C).
3298/// The output of winograd_filter_transform is (alphaH, alphaW, C, F)
3299/// Users can specify the tile sizes of F and C.
3300/// `offsets` are the values for the offsets of F, KH, KW, C for one tile.
3301/// `sizes` are the values for the sizes of F, KH, KW, C for one tile.
3302FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
3303 OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
3304 ArrayRef<OpFoldResult> sizes) {
3305 IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3306 IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3307 ShapedType filterType = getFilterOperandType();
3308 ArrayRef<int64_t> filterShape = filterType.getShape();
3309 int64_t filterH = filterShape[getFilterHDim()];
3310 int64_t filterW = filterShape[getFilterWDim()];
3311 IntegerAttr filterHAttr = builder.getI64IntegerAttr(filterH);
3312 IntegerAttr filterWAttr = builder.getI64IntegerAttr(filterW);
3313 SmallVector<Value> tiledOperands;
3314 SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3315
3316 sliceOffsets.append(
3317 {offsets[getFilterFDim()], zeroAttr, zeroAttr, offsets[getFilterCDim()]});
3318 sliceSizes.append({sizes[getFilterFDim()], filterHAttr, filterWAttr,
3319 sizes[getFilterCDim()]});
3320 int64_t filterRank = getFilterOperandRank();
3321 SmallVector<OpFoldResult> filterStrides(filterRank, oneAttr);
3322 Location loc = getLoc();
3323 auto filterSlice = tensor::ExtractSliceOp::create(
3324 builder, loc, getFilter(), sliceOffsets, sliceSizes, filterStrides);
3325 tiledOperands.emplace_back(filterSlice);
3326
3327 SmallVector<OpFoldResult> resultOffsets, resultSizes;
3328 if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
3329 resultSizes)))
3330 return failure();
3331
3332 int64_t outputRank = getOutputOperandRank();
3333 SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
3334 auto outputSlice = tensor::ExtractSliceOp::create(
3335 builder, loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3336 tiledOperands.emplace_back(outputSlice);
3337
3338 SmallVector<Type> resultTypes;
3339 resultTypes.push_back(tiledOperands[1].getType());
3340 Operation *tiledOp =
3341 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3342
3343 return TilingResult{
3344 {tiledOp},
3345 SmallVector<Value>(tiledOp->getResults()),
3346 llvm::to_vector(ArrayRef<Operation *>{filterSlice, outputSlice})};
3347}
3348
3349//===----------------------------------------------------------------------===//
3350// WinogradInputTransformOp
3351//===----------------------------------------------------------------------===//
3352
3353LogicalResult WinogradInputTransformOp::verify() {
3354 auto inputType = cast<ShapedType>(getInput().getType());
3355 ArrayRef<int64_t> inputShape = inputType.getShape();
3356 int64_t inputH = inputShape[getInputHDim()];
3357 int64_t inputW = inputShape[getInputWDim()];
3358 WinogradConv2DFmr fmr = getFmr();
3359 int64_t m, r;
3360 std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3361 int64_t tileSize = m + r - 1;
3362
3363 auto outputType = cast<ShapedType>(getOutput().getType());
3364 ArrayRef<int64_t> outputShape = outputType.getShape();
3365 bool leftTransform = outputShape[getOutputAlphaHDim()] != 1;
3366 bool rightTransform = outputShape[getOutputAlphaWDim()] != 1;
3367
3368 SmallVector<int64_t> expectedOutputShape(6, inputH);
3369 if (ShapedType::isDynamic(inputH)) {
3370 expectedOutputShape[getOutputAlphaHDim()] = tileSize;
3371 expectedOutputShape[getOutputTileHDim()] = ShapedType::kDynamic;
3372 } else {
3373 expectedOutputShape[getOutputAlphaHDim()] = leftTransform ? tileSize : 1;
3374 expectedOutputShape[getOutputTileHDim()] =
3375 leftTransform ? (inputH - (r - 1)) / m : inputH;
3376 }
3377 if (ShapedType::isDynamic(inputW)) {
3378 expectedOutputShape[getOutputAlphaWDim()] = tileSize;
3379 expectedOutputShape[getOutputTileWDim()] = ShapedType::kDynamic;
3380 } else {
3381 expectedOutputShape[getOutputAlphaWDim()] = rightTransform ? tileSize : 1;
3382 expectedOutputShape[getOutputTileWDim()] =
3383 rightTransform ? (inputW - (r - 1)) / m : inputW;
3384 }
3385 expectedOutputShape[getOutputNDim()] = inputShape[getInputNDim()];
3386 expectedOutputShape[getOutputCDim()] = inputShape[getInputCDim()];
3387
3388 if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
3389 return emitOpError("the output shape is not expected");
3390 }
3391 return success();
3392}
3393
3394SmallVector<Range>
3395WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
3396 Location loc = getLoc();
3397 IntegerAttr zeroAttr = builder.getIndexAttr(0);
3398 IntegerAttr oneAttr = builder.getIndexAttr(1);
3399 Value output = getOutput();
3400 int64_t outputRank = getOutputOperandRank();
3401 SmallVector<Range> loopBounds(outputRank);
3402 for (unsigned dim = 0; dim < outputRank; ++dim) {
3403 loopBounds[dim].offset = zeroAttr;
3404 // alphaH, alphaW, tileH, tileW, N, C
3405 loopBounds[dim].size = getDimValue(builder, loc, output, dim);
3406 loopBounds[dim].stride = oneAttr;
3407 }
3408 return loopBounds;
3409}
3410
3411SmallVector<utils::IteratorType>
3412WinogradInputTransformOp::getLoopIteratorTypes() {
3413 int64_t outputRank = getOutputOperandRank();
3414 SmallVector<utils::IteratorType> iteratorTypes(outputRank,
3415 utils::IteratorType::parallel);
3416 return iteratorTypes;
3417}
3418
3419LogicalResult WinogradInputTransformOp::getResultTilePosition(
3420 OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3421 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3422 SmallVector<OpFoldResult> &resultSizes) {
3423 IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3424 ShapedType outputType = getOutputOperandType();
3425 ArrayRef<int64_t> outputShape = outputType.getShape();
3426 int64_t outputAlphaH = outputShape[getOutputAlphaHDim()];
3427 int64_t outputAlphaW = outputShape[getOutputAlphaWDim()];
3428
3429 WinogradConv2DFmr fmr = getFmr();
3430 int64_t m, r;
3431 std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3432 int64_t alpha = m + r - 1;
3433 int64_t alphaH = outputAlphaH != 1 ? alpha : 1;
3434 int64_t alphaW = outputAlphaW != 1 ? alpha : 1;
3435
3436 IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
3437 IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
3438
3439 resultOffsets.append({zeroAttr, zeroAttr, offsets[getOutputTileHDim()],
3440 offsets[getOutputTileWDim()], offsets[getOutputNDim()],
3441 offsets[getOutputCDim()]});
3442 resultSizes.append({alphaHAttr, alphaWAttr, sizes[getOutputTileHDim()],
3443 sizes[getOutputTileWDim()], sizes[getOutputNDim()],
3444 sizes[getOutputCDim()]});
3445
3446 return success();
3447}
3448
3449/// Implement tiling for winograd_input_transform
3450/// The input of winograd_input_transform is (N, H, W, C).
3451/// The output of winograd_input_transform is (alphaH, alphaW, tileH, tileW, N,
3452/// C) Users can specify the tile sizes of tileH, tileW, N, and C. `offsets` are
3453/// the values for the offsets of tileH, tileW, N, C for one tile. `sizes` are
3454/// the values for the sizes of tileH, tileW, N, C for one tile.
3455FailureOr<TilingResult>
3456WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
3457 ArrayRef<OpFoldResult> offsets,
3458 ArrayRef<OpFoldResult> sizes) {
3459 IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3460 WinogradConv2DFmr fmr = getFmr();
3461 int64_t m, r;
3462 std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3463
3464 ShapedType outputType = getOutputOperandType();
3465 ArrayRef<int64_t> outputShape = outputType.getShape();
3466 int64_t alphaH = outputShape[getOutputAlphaHDim()];
3467 int64_t alphaW = outputShape[getOutputAlphaWDim()];
3468
3469 Location loc = getLoc();
3470 MLIRContext *context = builder.getContext();
3471 auto identityAffineMap =
3472 AffineMap::get(1, 0, {builder.getAffineDimExpr(0)}, context);
3473 auto offsetAffineMap =
3474 AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
3475 Value mappedOffsetH = affine::makeComposedAffineApply(
3476 builder, loc, (alphaH != 1 ? offsetAffineMap : identityAffineMap),
3477 offsets[getOutputTileHDim()]);
3478 Value mappedOffsetW = affine::makeComposedAffineApply(
3479 builder, loc, (alphaW != 1 ? offsetAffineMap : identityAffineMap),
3480 offsets[getOutputTileWDim()]);
3481 auto sizeAffineMap = AffineMap::get(
3482 1, 0, {builder.getAffineDimExpr(0) * m + (r - 1)}, context);
3483 Value mappedSizeH = affine::makeComposedAffineApply(
3484 builder, loc, sizeAffineMap, sizes[getOutputTileHDim()]);
3485 Value mappedSizeW = affine::makeComposedAffineApply(
3486 builder, loc, sizeAffineMap, sizes[getOutputTileWDim()]);
3487
3488 SmallVector<Value> tiledOperands;
3489 SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3490
3491 OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
3492 OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
3493 sliceOffsets.append(
3494 {offsets[getOutputNDim()], offsetH, offsetW, offsets[getOutputCDim()]});
3495 OpFoldResult sizeH =
3496 alphaH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
3497 OpFoldResult sizeW =
3498 alphaW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
3499 sliceSizes.append(
3500 {sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
3501 int64_t inputRank = getInputOperandRank();
3502 SmallVector<OpFoldResult> inputStrides(inputRank, oneAttr);
3503 auto inputSlice = tensor::ExtractSliceOp::create(
3504 builder, loc, getInput(), sliceOffsets, sliceSizes, inputStrides);
3505 tiledOperands.emplace_back(inputSlice);
3506
3507 SmallVector<OpFoldResult> resultOffsets, resultSizes;
3508 if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
3509 resultSizes)))
3510 return failure();
3511
3512 int64_t outputRank = getOutputOperandRank();
3513 SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
3514 auto outputSlice = tensor::ExtractSliceOp::create(
3515 builder, loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3516 tiledOperands.emplace_back(outputSlice);
3517
3518 SmallVector<Type> resultTypes;
3519 resultTypes.push_back(tiledOperands[1].getType());
3520 Operation *tiledOp =
3521 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3522
3523 return TilingResult{
3524 {tiledOp},
3525 SmallVector<Value>(tiledOp->getResults()),
3526 llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
3527}
3528
3529//===----------------------------------------------------------------------===//
3530// WinogradOutputTransformOp
3531//===----------------------------------------------------------------------===//
3532
3533LogicalResult WinogradOutputTransformOp::verify() {
3534 auto valueType = cast<ShapedType>(getValue().getType());
3535 ArrayRef<int64_t> valueShape = valueType.getShape();
3536 int64_t valueH = valueShape[getValueAlphaHDim()];
3537 int64_t valueW = valueShape[getValueAlphaWDim()];
3538 int64_t valueTileH = valueShape[getValueTileHDim()];
3539 int64_t valueTileW = valueShape[getValueTileWDim()];
3540 WinogradConv2DFmr fmr = getFmr();
3541 int64_t m, r;
3542 std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3543 bool leftTransform = valueH != 1;
3544 bool rightTransform = valueW != 1;
3545
3546 int64_t outputRank = getOutputOperandRank();
3547 SmallVector<int64_t> expectedOutputShape(outputRank, valueH);
3548 if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) {
3549 expectedOutputShape[getOutputHDim()] = ShapedType::kDynamic;
3550 } else {
3551 if (valueH != (leftTransform ? m + r - 1 : 1))
3552 return emitOpError("expect input height equals to input tile size");
3553 expectedOutputShape[getOutputHDim()] = (leftTransform ? m : 1) * valueTileH;
3554 }
3555 if (ShapedType::isDynamic(valueW) || ShapedType::isDynamic(valueTileW)) {
3556 expectedOutputShape[getOutputWDim()] = ShapedType::kDynamic;
3557 } else {
3558 if (valueW != (rightTransform ? m + r - 1 : 1))
3559 return emitOpError("expect input width equals to input tile size");
3560 expectedOutputShape[getOutputWDim()] =
3561 (rightTransform ? m : 1) * valueTileW;
3562 }
3563 expectedOutputShape[getOutputNDim()] = valueShape[getValueNDim()];
3564 expectedOutputShape[getOutputFDim()] = valueShape[getValueFDim()];
3565
3566 auto outputType = cast<ShapedType>(getOutput().getType());
3567 ArrayRef<int64_t> outputShape = outputType.getShape();
3568 if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
3569 return emitOpError("the output shape is not expected");
3570 }
3571 return success();
3572}
3573
3574SmallVector<Range>
3575WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
3576 Location loc = getLoc();
3577 IntegerAttr zeroAttr = builder.getIndexAttr(0);
3578 IntegerAttr oneAttr = builder.getIndexAttr(1);
3579 Value value = getValue();
3580 int64_t valueRank = getValueOperandRank();
3581 SmallVector<Range> loopBounds(valueRank);
3582 for (unsigned dim = 0; dim < valueRank; ++dim) {
3583 loopBounds[dim].offset = zeroAttr;
3584 // alphaH, alphaW, tileH, tileW, N, F
3585 loopBounds[dim].size = getDimValue(builder, loc, value, dim);
3586 loopBounds[dim].stride = oneAttr;
3587 }
3588 return loopBounds;
3589}
3590
3591SmallVector<utils::IteratorType>
3592WinogradOutputTransformOp::getLoopIteratorTypes() {
3593 int64_t valueRank = getValueOperandRank();
3594 SmallVector<utils::IteratorType> iteratorTypes(valueRank,
3595 utils::IteratorType::parallel);
3596 return iteratorTypes;
3597}
3598
3599LogicalResult WinogradOutputTransformOp::getResultTilePosition(
3600 OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3601 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3602 SmallVector<OpFoldResult> &resultSizes) {
3603 WinogradConv2DFmr fmr = getFmr();
3604 int64_t m, r;
3605 std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3606
3607 Location loc = getLoc();
3608 MLIRContext *context = builder.getContext();
3609 auto identityAffineMap =
3610 AffineMap::get(1, 0, {builder.getAffineDimExpr(0)}, context);
3611 auto affineMap =
3612 AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
3613
3614 ShapedType valueType = getValueOperandType();
3615 ArrayRef<int64_t> valueShape = valueType.getShape();
3616 int64_t valueH = valueShape[0];
3617 int64_t valueW = valueShape[1];
3618 Value mappedOffsetH = affine::makeComposedAffineApply(
3619 builder, loc, (valueH != 1 ? affineMap : identityAffineMap),
3620 offsets[getValueTileHDim()]);
3621 Value mappedOffsetW = affine::makeComposedAffineApply(
3622 builder, loc, (valueW != 1 ? affineMap : identityAffineMap),
3623 offsets[getValueTileWDim()]);
3624 Value mappedSizeH = affine::makeComposedAffineApply(
3625 builder, loc, affineMap, sizes[getValueTileHDim()]);
3626 Value mappedSizeW = affine::makeComposedAffineApply(
3627 builder, loc, affineMap, sizes[getValueTileWDim()]);
3628
3629 IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3630 OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
3631 OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
3632 OpFoldResult sizeH =
3633 valueH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
3634 OpFoldResult sizeW =
3635 valueW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
3636
3637 resultOffsets.append(
3638 {offsets[getValueNDim()], offsetH, offsetW, offsets[getValueFDim()]});
3639 resultSizes.append(
3640 {sizes[getValueNDim()], sizeH, sizeW, sizes[getValueFDim()]});
3641 return success();
3642}
3643
3644/// Implement tiling for winograd_output_transform
3645/// The input of winograd_output_transform is (alphaH, alphaW, tileH, tileW, N,
3646/// F). The output of winograd_output_transform is (N, H, W, F) Users can
3647/// specify the tile sizes of tileH, tileW, N, and F. `offsets` are the values
3648/// for the offsets of tileH, tileW, N, F for one tile. `sizes` are the values
3649/// for the sizes of tileH, tileW, N, F for one tile.
3650FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
3651 OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
3652 ArrayRef<OpFoldResult> sizes) {
3653 IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3654 IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3655 Location loc = getLoc();
3656 SmallVector<Value> tiledOperands;
3657 SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3658
3659 ShapedType valueType = getValueOperandType();
3660 ArrayRef<int64_t> valueShape = valueType.getShape();
3661 int64_t alphaH = valueShape[getValueAlphaHDim()];
3662 int64_t alphaW = valueShape[getValueAlphaWDim()];
3663 IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
3664 IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
3665
3666 sliceOffsets.append({zeroAttr, zeroAttr, offsets[getValueTileHDim()],
3667 offsets[getValueTileWDim()], offsets[getValueNDim()],
3668 offsets[getValueFDim()]});
3669 sliceSizes.append({alphaHAttr, alphaWAttr, sizes[getValueTileHDim()],
3670 sizes[getValueTileWDim()], sizes[getValueNDim()],
3671 sizes[getValueFDim()]});
3672 int64_t valueRank = getValueOperandRank();
3673 SmallVector<OpFoldResult> sliceStrides(valueRank, oneAttr);
3674 auto valueSlice = tensor::ExtractSliceOp::create(
3675 builder, loc, getValue(), sliceOffsets, sliceSizes, sliceStrides);
3676 tiledOperands.emplace_back(valueSlice);
3677
3678 SmallVector<OpFoldResult> resultOffsets, resultSizes;
3679 if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
3680 resultSizes)))
3681 return failure();
3682
3683 int64_t outputRank = getOutputOperandRank();
3684 SmallVector<OpFoldResult> strides(outputRank, oneAttr);
3685 auto outputSlice = tensor::ExtractSliceOp::create(
3686 builder, loc, getOutput(), resultOffsets, resultSizes, strides);
3687 tiledOperands.emplace_back(outputSlice);
3688
3689 SmallVector<Type> resultTypes;
3690 resultTypes.push_back(tiledOperands[1].getType());
3691 Operation *tiledOp =
3692 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3693
3694 return TilingResult{
3695 {tiledOp},
3696 SmallVector<Value>(tiledOp->getResults()),
3697 llvm::to_vector(ArrayRef<Operation *>{valueSlice, outputSlice})};
3698}
3699
3700//===----------------------------------------------------------------------===//
3701// LinalgDialect
3702// TODO: Merge with the LinalgDialect block at the bottom
3703//===----------------------------------------------------------------------===//
3704
3705// Returns true if the result expression of `subMap` are a subset of `fullMap`.
3706static bool areResultExprsSubsetOf(AffineMap subMap, AffineMap fullMap) {
3707 auto explicitRange = subMap.getResults();
3708 auto defaultRange = fullMap.getResults();
3709 DenseSet<AffineExpr> explicitSet(explicitRange.begin(), explicitRange.end());
3710 DenseSet<AffineExpr> defaultSet(defaultRange.begin(), defaultRange.end());
3711 llvm::set_union(explicitSet, defaultSet);
3712 return explicitSet == defaultSet;
3713}
3714
3715/// Check if the user defined map is valid broadcast map. Here broadcast
3716/// indexing maps are defined in context of corresponding default indexing maps
3717/// for the given Op. This way the check becomes very simple i.e just check the
3718/// number of result dims.
3719/// Returns true if the explictMap is broadcasted with respect to the
3720/// defaultMap.
3721static bool isBroadcasted(AffineMap explictMap, AffineMap defaultMap) {
3722 return explictMap.getNumResults() < defaultMap.getNumResults();
3723}
3724
3725/// Verifies the broadcast and transpose semantic sepecified by the explicit
3726/// indexing map for the MatmulOp \p op for each operand specified by \p
3727/// opIndex.
3728static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
3729 unsigned opIndex) {
3730 SmallVector<AffineMap, 3> opIndexingMaps = matmulOp.getIndexingMapsArray();
3731 SmallVector<AffineMap, 3> defaultIndexingMaps =
3732 matmulOp.getDefaultIndexingMaps(matmulOp->getContext());
3733
3734 auto opIndexingMap = opIndexingMaps[opIndex];
3735 auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3736 // Check general validity of indexing map results.
3737 if (!areResultExprsSubsetOf(opIndexingMap, defaultIndexingMap))
3738 return matmulOp->emitOpError()
3739 << "Unexpected dim expression in map result.";
3740
3741 if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
3742 if (!matmulOp.isValidLhsRhsBroadcastMap(opIndexingMap)) {
3743 return matmulOp->emitOpError()
3744 << "Invalid broadcast requested, should be (d2).";
3745 }
3746 return success();
3747 }
3748 return success();
3749}
3750
3751// Check general validity of input indexing map of
3752// BatchMatmulOp/BatchReduceMatmulOp.
3753template <typename OpTy>
3754static LogicalResult verifyInputMaps(OpTy batchVariantMatmulOp,
3755 AffineMap opIndexingMap,
3756 AffineMap defaultIndexingMap, bool isLHS) {
3757 assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
3758 isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
3759 "Expected BatchMatmulOp or BatchReduceMatmulOp");
3760 // Check the result dims are valid.
3761 if (!areResultExprsSubsetOf(opIndexingMap, defaultIndexingMap))
3762 return batchVariantMatmulOp->emitOpError()
3763 << "Unexpected result dim expression (outside the set of default "
3764 "result dims).";
3765
3766 // Check for valid number of result dims of input maps.
3767 if (opIndexingMap.getNumResults() > 3)
3768 return batchVariantMatmulOp->emitOpError()
3769 << "no. of result dim expressions exceeds 3.";
3770
3771 auto hasValidBatchDim = [](AffineMap map) {
3772 AffineExpr batchDim = map.getResult(0);
3773 return batchDim.isFunctionOfDim(0);
3774 };
3775
3776 // Check if the requested broadcast is valid.
3777 if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
3778 if (!batchVariantMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))
3779 return batchVariantMatmulOp->emitOpError()
3780 << "Invalid broadcast requested.";
3781 } else if (!hasValidBatchDim(opIndexingMap)) {
3782 return batchVariantMatmulOp->emitOpError()
3783 << "Invalid batch dimension expression.";
3784 }
3785 return success();
3786}
3787
3788/// This function checks if the given AffineMap for the output of a
3789/// BatchMatmulOp/BatchReduceMatmulOp has exactly the desired number of result
3790/// dimensions and if the output map result dimensions are valid.
3791template <typename OpTy>
3792static LogicalResult verifyOutputMap(OpTy batchVariantMatmulOp,
3793 AffineMap opIndexingMap) {
3794 assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
3795 isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
3796 "Expected BatchMatmulOp or BatchReduceMatmulOp");
3797 if (isa<BatchMatmulOp>(batchVariantMatmulOp) &&
3798 opIndexingMap.getNumResults() != 3) {
3799
3800 return batchVariantMatmulOp->emitOpError()
3801 << "expects 3 dims, but got (" << opIndexingMap.getNumResults()
3802 << ").";
3803 }
3804 if (isa<BatchReduceMatmulOp>(batchVariantMatmulOp) &&
3805 opIndexingMap.getNumResults() != 2) {
3806 return batchVariantMatmulOp->emitOpError()
3807 << "expects 2 dims, but got (" << opIndexingMap.getNumResults()
3808 << ").";
3809 }
3810
3811 auto areValidOutputResultDim = [&](AffineMap outputMap) {
3812 return isa<BatchMatmulOp>(batchVariantMatmulOp)
3813 ? outputMap.getResult(0).isFunctionOfDim(0) &&
3814 outputMap.getResult(1).isFunctionOfDim(1) &&
3815 outputMap.getResult(2).isFunctionOfDim(2)
3816 : outputMap.getResult(0).isFunctionOfDim(1) &&
3817 outputMap.getResult(1).isFunctionOfDim(2);
3818 };
3819
3820 if (!areValidOutputResultDim(opIndexingMap)) {
3821 return batchVariantMatmulOp->emitOpError()
3822 << "Invalid output map result dimension.";
3823 }
3824
3825 return success();
3826}
3827
3828/// Verifies the broadcast and transpose semantic specified by the explicit
3829/// indexing map for the BatchMatmulOp/BatchReduceMatmulOp op for each operand
3830/// specified by opIndex.
3831template <typename OpTy>
3832static LogicalResult
3834 unsigned opIndex) {
3835 SmallVector<AffineMap, 3> opIndexingMaps =
3836 batchVariantMatmulOp.getIndexingMapsArray();
3837 SmallVector<AffineMap, 3> defaultIndexingMaps =
3838 batchVariantMatmulOp.getDefaultIndexingMaps(
3839 batchVariantMatmulOp->getContext());
3840
3841 if (opIndexingMaps.size() != 3)
3842 return batchVariantMatmulOp->emitOpError()
3843 << "Indexing_map attribute must have 3 affine maps.";
3844
3845 auto opIndexingMap = opIndexingMaps[opIndex];
3846 auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3847
3848 if (opIndex == 2 &&
3849 failed(verifyOutputMap(batchVariantMatmulOp, opIndexingMap)))
3850 return failure();
3851
3852 if (opIndex != 2 &&
3853 failed(verifyInputMaps(batchVariantMatmulOp, opIndexingMap,
3854 defaultIndexingMap, opIndex == 0)))
3855 return failure();
3856
3857 return success();
3858}
3859
3860namespace mlir {
3861namespace linalg {
3862
3863std::optional<WinogradConv2DFmr> getWinogradConv2DFmr(int64_t m, int64_t r) {
3864 if (m == 2 && r == 3)
3865 return WinogradConv2DFmr::F_2_3;
3866 if (m == 4 && r == 3)
3867 return WinogradConv2DFmr::F_4_3;
3868 if (m == 2 && r == 5)
3869 return WinogradConv2DFmr::F_2_5;
3870 return std::nullopt;
3871}
3872
3873std::pair<int64_t, int64_t> getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr) {
3874 switch (fmr) {
3875 case WinogradConv2DFmr::F_2_3:
3876 return {2, 3};
3877 case WinogradConv2DFmr::F_4_3:
3878 return {4, 3};
3879 case WinogradConv2DFmr::F_2_5:
3880 return {2, 5};
3881 }
3882 llvm_unreachable("Unkown WinogradConv2DFmr");
3883}
3884
3885//===----------------------------------------------------------------------===//
3886// MatMulOp
3887//===----------------------------------------------------------------------===//
3888
3889static FailureOr<SmallVector<SmallVector<int64_t>>>
3892 for (auto map : maps) {
3893 AffineMapAttr attr = dyn_cast<AffineMapAttr>(map);
3894 if (!attr)
3895 return failure();
3897 for (auto result : attr.getAffineMap().getResults()) {
3898 auto dim = dyn_cast<AffineDimExpr>(result);
3899 if (!dim)
3900 return failure();
3901 pos.push_back(dim.getPosition());
3902 }
3903 positions.push_back(pos);
3904 }
3905 return positions;
3906}
3907
3908/// Returns a list of AffineMap with the typical matmul indexing charactristic.
3909SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
3910 AffineExpr d0, d1, d2;
3911 SmallVector<AffineMap> indexingMaps;
3912 bindDims(context, d0, d1, d2);
3913 indexingMaps.push_back(AffineMap::get(3, 0, {d0, d2}, context));
3914 indexingMaps.push_back(AffineMap::get(3, 0, {d2, d1}, context));
3915 indexingMaps.push_back(AffineMap::get(3, 0, {d0, d1}, context));
3916 return indexingMaps;
3917}
3918
3919bool MatmulOp::isDefaultIndexingMaps(Attribute attr) {
3920 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
3921 if (!maps)
3922 return false;
3923 if (maps.size() != 3)
3924 return false;
3925 auto positions = getAffineResultPositions(maps);
3926 if (failed(positions))
3927 return false;
3928 return (*positions)[0] == SmallVector<int64_t>{0, 2} &&
3929 (*positions)[1] == SmallVector<int64_t>{2, 1} &&
3930 (*positions)[2] == SmallVector<int64_t>{0, 1};
3931}
3932
3933SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
3934 return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
3935 utils::IteratorType::parallel,
3936 utils::IteratorType::reduction};
3937}
3938
3939unsigned MatmulOp::getNumRegionArgs() { return 3; }
3940
3941std::string MatmulOp::getLibraryCallName() {
3942 return generateLibraryCallName(getOperation());
3943}
3944
3945bool MatmulOp::hasDynamicIndexingMaps() { return true; }
3946
3947/// Check if the op has broadcast and/or transpose semantic. Returns true if
3948/// the user defined indexing maps are not equal to default map.
3949bool MatmulOp::hasUserDefinedMaps() {
3950 SmallVector<AffineMap, 3> defaultMaps =
3951 getDefaultIndexingMaps(this->getContext());
3952 SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
3953 return defaultMaps != explicitMaps;
3954}
3955
3956/// Implements the block region builder for the MatmulOp. This is called by
3957/// 'fillStructuredOpRegion'.
3958void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
3959 ArrayRef<NamedAttribute> attrs,
3960 function_ref<InFlightDiagnostic()> emitError) {
3961 if (emitError && block.getNumArguments() != 3) {
3962 emitError() << "MatmulOp regionBuilder expects 3 args, got "
3963 << block.getNumArguments();
3964 return;
3965 }
3966 assert(block.getNumArguments() == 3 &&
3967 "MatmulOp regionBuilder expects 3 args");
3968 RegionBuilderHelper helper(b, block);
3969 SmallVector<Value> yields;
3970
3971 TypeFn castVal = TypeFn::cast_signed;
3972 const auto *castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
3973 return attr.getName() == "cast";
3974 });
3975 if (castIter != attrs.end()) {
3976 if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
3977 castVal = attr.getValue();
3978 }
3979
3980 Value value1 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
3981 block.getArgument(0));
3982 Value value2 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
3983 block.getArgument(1));
3984 Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2, emitError);
3985 if (!value1 || !value2 || !value3)
3986 return;
3987 Value value4 = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
3988 value3, emitError);
3989 if (!value4)
3990 return;
3991 yields.push_back(value4);
3992 helper.yieldOutputs(yields);
3993}
3994
3995/// Returns true if the given bcastMap map is a valid broadcast map. A valid
3996/// broadcast map must include K dimension.
3997/// TODO: Strict inclusion of K dimension in the broadcast map is not
3998/// necessary for both input matrices simultaneously. We can relax this
3999/// condition to have K dimension for one input matrix map and infer the K
4000/// dimension for other input matrix map from the one already having K
4001/// dimension.
4002bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
4003 assert(bcastMap.getNumResults() == 1 && "Expected single result dim expr.");
4004 AffineExpr expr = bcastMap.getResult(0);
4005 // Invalid map if the common dimension of matmul not found.
4006 return expr.isFunctionOfDim(bcastMap.getNumDims() - 1);
4007}
4008
4009static FailureOr<ArrayAttr> parseIndexingMapsAttr(OpAsmParser &parser) {
4010 if (parser.parseOptionalKeyword("indexing_maps"))
4011 return ArrayAttr{
4012 nullptr}; // Success in case indexing_maps was not provided.
4013
4014 ArrayAttr arrayAttr;
4015 if (parser.parseEqual() || parser.parseAttribute(arrayAttr))
4016 return failure();
4017
4018 if (llvm::any_of(arrayAttr,
4019 [](auto elt) { return !dyn_cast<AffineMapAttr>(elt); }))
4020 return parser.emitError(parser.getCurrentLocation())
4021 << "element of indexing_maps array is not an affine_map";
4022
4023 return arrayAttr;
4024}
4025
4026ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
4027 FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
4028 if (failed(indexingMapsAttr))
4029 return failure();
4030
4031 if (*indexingMapsAttr == nullptr) {
4032 auto indexingMapAttrs = llvm::map_to_vector(
4033 MatmulOp::getDefaultIndexingMaps(parser.getContext()),
4034 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4035 indexingMapsAttr = parser.getBuilder().getArrayAttr(indexingMapAttrs);
4036 }
4037
4038 result.addAttribute("indexing_maps", *indexingMapsAttr);
4039 return parseNamedStructuredOp(parser, result, MatmulOp::getNumRegionArgs(),
4040 MatmulOp::getRegionBuilder());
4041}
4042
4043void MatmulOp::print(OpAsmPrinter &p) {
4044 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
4045 MatmulOp::getDefaultIndexingMaps(getContext()),
4046 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4047 if (!llvm::equal(getIndexingMaps(), indexingMaps))
4048 p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4049
4050 std::array<StringRef, 3> elidedAttrs = {
4051 "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
4052 printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
4053 elidedAttrs);
4054}
4055
4056/// Verify the user defined indexing maps.
4057LogicalResult MatmulOp::verify() {
4058 // Verification of pure matmul is handled by verifyStructuredOpInterface().
4059 if (!hasUserDefinedMaps())
4060 return success();
4061
4062 for (unsigned opIndex = 0; opIndex < 2; opIndex++) {
4063 if (failed(verifyExtendedMatmulSemantic(*this, opIndex)))
4064 return failure();
4065 }
4066 return success();
4067}
4068
4069LogicalResult MatmulOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
4070 return memref::foldMemRefCast(*this);
4071}
4072
4073void MatmulOp::getEffects(
4074 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4075 &effects) {
4076 if (hasPureTensorSemantics())
4077 return;
4078 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
4079}
4080
4081Speculation::Speculatability MatmulOp::getSpeculatability() {
4082 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
4083}
4084
4085SmallVector<AffineMap>
4086MatmulTransposeAOp::getDefaultIndexingMaps(OpBuilder &builder) {
4087 AffineExpr d0, d1, d2;
4088 MLIRContext *context = builder.getContext();
4089 bindDims(context, d0, d1, d2);
4090 AffineMap mapLHS = AffineMap::get(3, 0, {d2, d0}, context);
4091 AffineMap mapRHS = AffineMap::get(3, 0, {d2, d1}, context);
4092 AffineMap mapOut = AffineMap::get(3, 0, {d0, d1}, context);
4093 return {mapLHS, mapRHS, mapOut};
4094}
4095
4097 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4098 if (!maps)
4099 return false;
4100 if (maps.size() != 3)
4101 return false;
4102 auto positions = getAffineResultPositions(maps);
4103 if (failed(positions))
4104 return false;
4105 return (*positions)[0] == SmallVector<int64_t>{2, 0} &&
4106 (*positions)[1] == SmallVector<int64_t>{2, 1} &&
4107 (*positions)[2] == SmallVector<int64_t>{0, 1};
4108}
4109
4112 ValueRange inputs, ValueRange outputs,
4113 ArrayRef<NamedAttribute> attributes) {
4114 buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
4115 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4116}
4117
4120 ValueRange inputs, ValueRange outputs,
4121 ArrayRef<NamedAttribute> attributes) {
4122 OperationState state(location, getOperationName());
4123 build(builder, state, inputs, outputs, attributes);
4124 auto res = dyn_cast<MatmulTransposeAOp>(builder.create(state));
4125 assert(res && "builder didn't return the right type");
4126 return res;
4127}
4128
4131 TypeRange resultTensorTypes,
4132 ValueRange inputs, ValueRange outputs,
4133 ArrayRef<NamedAttribute> attributes) {
4134 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4135 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4136}
4137
4140 TypeRange resultTensorTypes, ValueRange inputs,
4141 ValueRange outputs,
4142 ArrayRef<NamedAttribute> attributes) {
4143 OperationState state(location, getOperationName());
4144 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4145 auto res = dyn_cast<MatmulTransposeAOp>(builder.create(state));
4146 assert(res && "builder didn't return the right type");
4147 return res;
4148}
4149
4152 TypeRange resultTensorTypes,
4153 ValueRange inputs, ValueRange outputs,
4154 Attribute cast,
4155 ArrayRef<NamedAttribute> attributes) {
4156 result.addAttribute("cast", cast);
4157 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4158 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4159}
4160
4163 TypeRange resultTensorTypes, ValueRange inputs,
4164 ValueRange outputs, Attribute cast,
4165 ArrayRef<NamedAttribute> attributes) {
4166 OperationState state(location, getOperationName());
4167 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4168 auto res = dyn_cast<MatmulTransposeAOp>(builder.create(state));
4169 assert(res && "builder didn't return the right type");
4170 return res;
4171}
4172
4174 return dyn_cast_or_null<linalg::MatmulOp>(op) &&
4176 op->getAttr("indexing_maps"));
4177}
4178
4180MatmulTransposeBOp::getDefaultIndexingMaps(OpBuilder &builder) {
4181 AffineExpr d0, d1, d2;
4182 MLIRContext *context = builder.getContext();
4183 bindDims(context, d0, d1, d2);
4184 AffineMap mapLHS = AffineMap::get(3, 0, {d0, d2}, context);
4185 AffineMap mapRHS = AffineMap::get(3, 0, {d1, d2}, context);
4186 AffineMap mapOut = AffineMap::get(3, 0, {d0, d1}, context);
4187 return {mapLHS, mapRHS, mapOut};
4188}
4189
4191 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4192 if (!maps)
4193 return false;
4194 if (maps.size() != 3)
4195 return false;
4196 auto positions = getAffineResultPositions(maps);
4197 if (failed(positions))
4198 return false;
4199 return (*positions)[0] == SmallVector<int64_t>{0, 2} &&
4200 (*positions)[1] == SmallVector<int64_t>{1, 2} &&
4201 (*positions)[2] == SmallVector<int64_t>{0, 1};
4202}
4203
4206 ValueRange inputs, ValueRange outputs,
4207 ArrayRef<NamedAttribute> attributes) {
4208 buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
4209 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4210}
4211
4214 ValueRange inputs, ValueRange outputs,
4215 ArrayRef<NamedAttribute> attributes) {
4216 OperationState state(location, getOperationName());
4217 build(builder, state, inputs, outputs, attributes);
4218 auto res = dyn_cast<MatmulTransposeBOp>(builder.create(state));
4219 assert(res && "builder didn't return the right type");
4220 return res;
4221}
4222
4225 TypeRange resultTensorTypes,
4226 ValueRange inputs, ValueRange outputs,
4227 ArrayRef<NamedAttribute> attributes) {
4228 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4229 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4230}
4231
4234 TypeRange resultTensorTypes, ValueRange inputs,
4235 ValueRange outputs,
4236 ArrayRef<NamedAttribute> attributes) {
4237 OperationState state(location, getOperationName());
4238 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4239 auto res = dyn_cast<MatmulTransposeBOp>(builder.create(state));
4240 assert(res && "builder didn't return the right type");
4241 return res;
4242}
4243
4246 TypeRange resultTensorTypes,
4247 ValueRange inputs, ValueRange outputs,
4248 Attribute cast,
4249 ArrayRef<NamedAttribute> attributes) {
4250 result.addAttribute("cast", cast);
4251 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4252 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4253}
4254
4257 TypeRange resultTensorTypes, ValueRange inputs,
4258 ValueRange outputs, Attribute cast,
4259 ArrayRef<NamedAttribute> attributes) {
4260 OperationState state(location, getOperationName());
4261 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4262 auto res = dyn_cast<MatmulTransposeBOp>(builder.create(state));
4263 assert(res && "builder didn't return the right type");
4264 return res;
4265}
4266
4268 return dyn_cast_or_null<linalg::MatmulOp>(op) &&
4270 op->getAttr("indexing_maps"));
4271}
4272
4274BatchMatmulTransposeAOp::getDefaultIndexingMaps(OpBuilder &builder) {
4275 AffineExpr d0, d1, d2, d3;
4276 MLIRContext *context = builder.getContext();
4277 bindDims(context, d0, d1, d2, d3);
4278 AffineMap mapLHS = AffineMap::get(4, 0, {d0, d3, d1}, context);
4279 AffineMap mapRHS = AffineMap::get(4, 0, {d0, d3, d2}, context);
4280 AffineMap mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context);
4281 return {mapLHS, mapRHS, mapOut};
4282}
4283
4285 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4286 if (!maps)
4287 return false;
4288 if (maps.size() != 3)
4289 return false;
4290 auto positions = getAffineResultPositions(maps);
4291 if (failed(positions))
4292 return false;
4293 return (*positions)[0] == SmallVector<int64_t>{0, 3, 1} &&
4294 (*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
4295 (*positions)[2] == SmallVector<int64_t>{0, 1, 2};
4296}
4297
4299 OpBuilder &builder, OperationState &result, ValueRange inputs,
4300 ValueRange outputs, ArrayRef<NamedAttribute> attributes) {
4301 buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
4302 BatchMatmulOp::getRegionBuilder(),
4303 getDefaultIndexingMaps(builder));
4304}
4305
4308 ValueRange inputs, ValueRange outputs,
4309 ArrayRef<NamedAttribute> attributes) {
4310 OperationState state(location, getOperationName());
4311 build(builder, state, inputs, outputs, attributes);
4312 auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.create(state));
4313 assert(res && "builder didn't return the right type");
4314 return res;
4315}
4316
4318 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
4319 ValueRange inputs, ValueRange outputs,
4320 ArrayRef<NamedAttribute> attributes) {
4321 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4322 BatchMatmulOp::getRegionBuilder(),
4323 getDefaultIndexingMaps(builder));
4324}
4325
4328 TypeRange resultTensorTypes, ValueRange inputs,
4329 ValueRange outputs,
4330 ArrayRef<NamedAttribute> attributes) {
4331 OperationState state(location, getOperationName());
4332 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4333 auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.create(state));
4334 assert(res && "builder didn't return the right type");
4335 return res;
4336}
4337
4339 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
4340 ValueRange inputs, ValueRange outputs, Attribute cast,
4341 ArrayRef<NamedAttribute> attributes) {
4342 result.addAttribute("cast", cast);
4343 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4344 BatchMatmulOp::getRegionBuilder(),
4345 getDefaultIndexingMaps(builder));
4346}
4347
4350 TypeRange resultTensorTypes, ValueRange inputs,
4351 ValueRange outputs, Attribute cast,
4352 ArrayRef<NamedAttribute> attributes) {
4353 OperationState state(location, getOperationName());
4354 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4355 auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.create(state));
4356 assert(res && "builder didn't return the right type");
4357 return res;
4358}
4359
4361 return dyn_cast_or_null<linalg::BatchMatmulOp>(op) &&
4363 op->getAttr("indexing_maps"));
4364}
4365
4367BatchMatmulTransposeBOp::getDefaultIndexingMaps(OpBuilder &builder) {
4368 AffineExpr d0, d1, d2, d3;
4369 MLIRContext *context = builder.getContext();
4370 bindDims(context, d0, d1, d2, d3);
4371 AffineMap mapLHS = AffineMap::get(4, 0, {d0, d1, d3}, context);
4372 AffineMap mapRHS = AffineMap::get(4, 0, {d0, d2, d3}, context);
4373 AffineMap mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context);
4374 return {mapLHS, mapRHS, mapOut};
4375}
4376
4378 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4379 if (!maps)
4380 return false;
4381 if (maps.size() != 3)
4382 return false;
4383 auto positions = getAffineResultPositions(maps);
4384 if (failed(positions))
4385 return false;
4386 return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
4387 (*positions)[1] == SmallVector<int64_t>{0, 2, 3} &&
4388 (*positions)[2] == SmallVector<int64_t>{0, 1, 2};
4389}
4390
4392 OpBuilder &builder, OperationState &result, ValueRange inputs,
4393 ValueRange outputs, ArrayRef<NamedAttribute> attributes) {
4394 buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
4395 BatchMatmulOp::getRegionBuilder(),
4396 getDefaultIndexingMaps(builder));
4397}
4398
4401 ValueRange inputs, ValueRange outputs,
4402 ArrayRef<NamedAttribute> attributes) {
4403 OperationState state(location, getOperationName());
4404 build(builder, state, inputs, outputs, attributes);
4405 auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.create(state));
4406 assert(res && "builder didn't return the right type");
4407 return res;
4408}
4409
4411 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
4412 ValueRange inputs, ValueRange outputs,
4413 ArrayRef<NamedAttribute> attributes) {
4414 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4415 BatchMatmulOp::getRegionBuilder(),
4416 getDefaultIndexingMaps(builder));
4417}
4418
4421 TypeRange resultTensorTypes, ValueRange inputs,
4422 ValueRange outputs,
4423 ArrayRef<NamedAttribute> attributes) {
4424 OperationState state(location, getOperationName());
4425 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4426 auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.create(state));
4427 assert(res && "builder didn't return the right type");
4428 return res;
4429}
4430
4432 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
4433 ValueRange inputs, ValueRange outputs, Attribute cast,
4434 ArrayRef<NamedAttribute> attributes) {
4435 result.addAttribute("cast", cast);
4436 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4437 BatchMatmulOp::getRegionBuilder(),
4438 getDefaultIndexingMaps(builder));
4439}
4440
4443 TypeRange resultTensorTypes, ValueRange inputs,
4444 ValueRange outputs, Attribute cast,
4445 ArrayRef<NamedAttribute> attributes) {
4446 OperationState state(location, getOperationName());
4447 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4448 auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.create(state));
4449 assert(res && "builder didn't return the right type");
4450 return res;
4451}
4452
4454 return dyn_cast_or_null<linalg::BatchMatmulOp>(op) &&
4456 op->getAttr("indexing_maps"));
4457}
4458
4459//===----------------------------------------------------------------------===//
4460// ContractOp
4461//===----------------------------------------------------------------------===//
4462
4463SmallVector<utils::IteratorType> ContractOp::getIteratorTypesArray() {
4464 AffineMap outAffineMap = getIndexingMapsArray().pop_back_val();
4465 // On well-formed IR, indexing_maps is non-empty, contained affine_maps'
4466 // domains are all the same, and each implements a projected permutation.
4467 // Each iteration space dim must occur for at least one operand and either
4468 // takes part in a contraction/reduction or else has parallel iteration type.
4469 // We have that a dim is a contraction/reduction dim if and only if the dim
4470 // occurs for the output operand. We use this fact for fast inference:
4471 // NB: In case we allow dims to occur solely for one input, the above still
4472 // holds: per the einsum semantics, these are reduction dims as well.
4473 SmallVector<bool> dimsInOutput(outAffineMap.getNumDims(), false);
4474 for (auto result : outAffineMap.getResults()) {
4475 auto dimExpr = dyn_cast<AffineDimExpr>(result);
4476 assert(dimExpr && "affine_map is a projected permutation");
4477 dimsInOutput[dimExpr.getPosition()] = true;
4478 }
4479
4481 for (auto dimOccursInOutput : dimsInOutput)
4482 iteratorTypes.push_back(dimOccursInOutput ? utils::IteratorType::parallel
4483 : utils::IteratorType::reduction);
4484
4485 return iteratorTypes;
4486}
4487
4488unsigned ContractOp::getNumRegionArgs() { return 3; }
4489
4490/// Implement block region builder, which is called by 'fillStructuredOpRegion'.
4491void ContractOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
4492 ArrayRef<NamedAttribute> attrs,
4493 function_ref<InFlightDiagnostic()> emitError) {
4494 if (emitError && block.getNumArguments() != 3) {
4495 emitError() << "ContractOp regionBuilder expects 3 args, got "
4496 << block.getNumArguments();
4497 return;
4498 }
4499 assert(block.getNumArguments() == 3 &&
4500 "ContractOp regionBuilder expects 3 args");
4501 RegionBuilderHelper helper(b, block);
4502
4503 TypeFn castSignedness = TypeFn::cast_signed;
4504 auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
4505 return attr.getName() == "cast";
4506 });
4507 if (castIter != attrs.end()) {
4508 if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
4509 castSignedness = attr.getValue();
4510 }
4511
4512 // TODO: Support fields with operators besides mult & add.
4513 Type outType = block.getArgument(2).getType();
4514 Value lhsAtOutType =
4515 helper.buildTypeFn(castSignedness, outType, block.getArgument(0));
4516 Value rhsAtOutType =
4517 helper.buildTypeFn(castSignedness, outType, block.getArgument(1));
4518 Value productAtOutType = helper.buildBinaryFn(BinaryFn::mul, lhsAtOutType,
4519 rhsAtOutType, emitError);
4520 if (!productAtOutType)
4521 return;
4522 Value result = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
4523 productAtOutType, emitError);
4524 if (!result)
4525 return;
4526 helper.yieldOutputs({result});
4527}
4528
4529ParseResult ContractOp::parse(OpAsmParser &parser, OperationState &result) {
4530 FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
4531 if (failed(indexingMapsAttr) || *indexingMapsAttr == nullptr)
4532 return parser.emitError(parser.getCurrentLocation(),
4533 "expected 'indexing_maps' attribute");
4534 result.addAttribute("indexing_maps", *indexingMapsAttr);
4535
4536 return parseNamedStructuredOp(parser, result, getNumRegionArgs(),
4537 regionBuilder);
4538}
4539
4540void ContractOp::print(OpAsmPrinter &p) {
4541 p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4543 p, getOperation(), getInputs(), getOutputs(),
4544 /*elidedAttrs=*/{"indexing_maps", "operandSegmentSizes"});
4545}
4546
4547LogicalResult ContractOp::verify() {
4548 int iterationSpaceDims = -1;
4549 // Map iter space dims to #occurrences in inputs' and output's affine_maps:
4550 // e.g., inOccurrences[0] will hold #times that dim (with index) 0 is used to
4551 // access an input operand (so occurrence count can be at most 2) and
4552 // outOccurrences[1] will indicate whether dim 1 occurred in the output, etc.
4553 SmallVector<size_t> inOccurrences;
4554 SmallVector<size_t> outOccurrences;
4555
4556 // A helper so that for each operand's affine_map and type we check that ...
4557 auto checkAffineMapAndType = [&](AffineMap affineMap, Type operandType,
4558 bool isInput) -> LogicalResult {
4559 // ... the affine_map is a projected permutation;
4560 if (!affineMap.isProjectedPermutation())
4561 return emitError("provided affine_map is not a projected permutation");
4562
4563 // ... the rank of the affine_map's results and corresponding type match;
4564 if (auto shapedType = dyn_cast<ShapedType>(operandType)) {
4565 if (affineMap.getNumResults() != shapedType.getRank())
4566 return emitError("ranks of shaped operand and results of corresponding "
4567 "affine_map differ");
4568 } else if (affineMap.getNumResults() != 0) {
4569 return emitError("affine_map specifies shaped access while operand has "
4570 "non-shaped type");
4571 }
4572
4573 // ... the rank of the affine_map's domain is the same as those seen prior;
4574 if (iterationSpaceDims == -1) {
4575 iterationSpaceDims = affineMap.getNumDims();
4576 inOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
4577 outOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
4578 } else if (iterationSpaceDims != (int)affineMap.getNumDims()) {
4579 return emitError("iteration spaces of provided affine_maps differ");
4580 }
4581
4582 // ... update counts of dims used to access either an input or the output.
4583 for (AffineExpr affineExpr : affineMap.getResults()) {
4584 auto affineDimExpr = dyn_cast<AffineDimExpr>(affineExpr);
4585 if (!affineDimExpr)
4586 llvm_unreachable("affine_map is a projected permutation");
4587
4588 if (isInput)
4589 inOccurrences[affineDimExpr.getPosition()] += 1;
4590 else
4591 outOccurrences[affineDimExpr.getPosition()] += 1;
4592 }
4593
4594 return success();
4595 };
4596
4597 for (auto &&[affineMap, operandType, isInput] :
4598 llvm::zip(getIndexingMapsArray(), getOperandTypes(),
4599 SmallVector<bool>{true, true, false})) {
4600 if (failed(checkAffineMapAndType(affineMap, operandType, isInput)))
4601 return failure(); // NB: checkAffineMapAndType will emit relevant error.
4602 }
4603
4604 bool hasContractingDim = false;
4605 for (size_t dimIndex = 0; dimIndex < (size_t)iterationSpaceDims; dimIndex++) {
4606 size_t inOccCount = inOccurrences[dimIndex];
4607 size_t outOccCount = outOccurrences[dimIndex];
4608
4609 // We have a contracting dim if and only if ...
4610 hasContractingDim |= inOccCount == 2 && outOccCount == 0;
4611
4612 if (inOccCount == 0 && outOccCount == 0)
4613 return emitError() << "iteration space dim at index " << dimIndex
4614 << " not used to access any operand";
4615
4616 // NB: We disallow a dim which occurs for only one input operand and not
4617 // for the output. In terms of einsum semantics such dims have a
4618 // sensible meaning - namely an additional reduction per each such dim.
4619 // By contrast, the ContractionOpInterface does not know about this
4620 // iter type - cf. inferContractionDims' supported dim kinds. Similarly,
4621 // while vector.contract's verifier accepts dims of this kind many of
4622 // its lowerings give up on encountering these dims.
4623 // TODO: Remove following once we have comprehensive support for input-only
4624 // reduction dims, at both the linalg- and vector-dialect levels.
4625 if (inOccCount == 1 && outOccCount != 1)
4626 return emitError()
4627 << "iteration space dim at index " << dimIndex
4628 << " is neither a contracting dim nor of parallel iteration type";
4629 }
4630
4631 if (!hasContractingDim)
4632 return emitError("'indexing_maps' do not specify a contracting dimension");
4633
4634 return success();
4635}
4636
4637LogicalResult ContractOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
4638 return memref::foldMemRefCast(*this);
4639}
4640
4641void ContractOp::getEffects(
4642 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4643 &effects) {
4644 if (hasPureTensorSemantics())
4645 return;
4646 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
4647}
4648
4649Speculation::Speculatability ContractOp::getSpeculatability() {
4650 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
4651}
4652
4653//===----------------------------------------------------------------------===//
4654// Implementation of BatchMatmulOp
4655//===----------------------------------------------------------------------===//
4656SmallVector<AffineMap>
4657BatchMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
4658 AffineExpr d0, d1, d2, d3;
4659 SmallVector<AffineMap> indexingMaps;
4660 bindDims(context, d0, d1, d2, d3);
4661 indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));
4662 indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));
4663 indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d2}, context));
4664 return indexingMaps;
4665}
4666
4667bool BatchMatmulOp::isDefaultIndexingMaps(Attribute attr) {
4668 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4669 if (!maps)
4670 return false;
4671 if (maps.size() != 3)
4672 return false;
4673 auto positions = getAffineResultPositions(maps);
4674 if (failed(positions))
4675 return false;
4676 return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
4677 (*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
4678 (*positions)[2] == SmallVector<int64_t>{0, 1, 2};
4679}
4680
4681SmallVector<utils::IteratorType> BatchMatmulOp::getIteratorTypesArray() {
4682 return SmallVector<utils::IteratorType>{
4683 utils::IteratorType::parallel, utils::IteratorType::parallel,
4684 utils::IteratorType::parallel, utils::IteratorType::reduction};
4685}
4686
4687unsigned BatchMatmulOp::getNumRegionArgs() { return 3; }
4688
4689std::string BatchMatmulOp::getLibraryCallName() {
4690 return generateLibraryCallName(getOperation());
4691}
4692
4693/// Check if the op has broadcast and/or transpose semantic. Returns true if
4694/// the user defined indexing maps are not equal to default map.
4695bool BatchMatmulOp::hasUserDefinedMaps() {
4696 SmallVector<AffineMap, 3> defaultMaps =
4697 getDefaultIndexingMaps(this->getContext());
4698 SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
4699 return defaultMaps != explicitMaps;
4700}
4701
4702/// Returns true if the given bcastMap map is a valid broadcast map. A valid
4703/// broadcast map must include K dimension.
4704/// TODO: Strict inclusion of K dimension in the broadcast map is not
4705/// necessary for both input matrices simultaneously. We can relax this
4706/// condition to have K dimension for one input matrix map and infer the K
4707/// dimension for other input matrix map from the one already having K
4708/// dimension.
4709bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS) {
4710 assert(bcastMap.getNumResults() < 3 &&
4711 "Expected less than 3 result dim expr.");
4712 bool isValid = false;
4713 enum Indices { batchPos, mPos, nPos, kPos };
4714 if (bcastMap.getNumResults() == 1) {
4715 AffineExpr expr = bcastMap.getResult(0);
4716 isValid = expr.isFunctionOfDim(kPos);
4717 } else if (bcastMap.getNumResults() == 2) {
4718 AffineExpr expr0 = bcastMap.getResult(0);
4719 AffineExpr expr1 = bcastMap.getResult(1);
4720 isValid =
4721 isLHS ? ((expr0.isFunctionOfDim(batchPos) ||
4722 expr0.isFunctionOfDim(mPos)) &&
4723 expr1.isFunctionOfDim(kPos))
4724 : ((expr0.isFunctionOfDim(batchPos) &&
4725 expr1.isFunctionOfDim(kPos)) ||
4726 (expr0.isFunctionOfDim(kPos) && expr1.isFunctionOfDim(nPos)));
4727 }
4728 return isValid;
4729}
4730
4731void BatchMatmulOp::regionBuilder(
4732 ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs,
4733 function_ref<InFlightDiagnostic()> emitError) {
4734 if (emitError && block.getNumArguments() != 3) {
4735 emitError() << "BatchMatmulOp regionBuilder expects 3 args, got "
4736 << block.getNumArguments();
4737 return;
4738 }
4739 assert(block.getNumArguments() == 3 &&
4740 "BatchMatmulOp regionBuilder expects 3 args");
4741 RegionBuilderHelper helper(b, block);
4742 SmallVector<Value> yields;
4743
4744 TypeFn castVal = TypeFn::cast_signed;
4745 auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
4746 return attr.getName() == "cast";
4747 });
4748 if (castIter != attrs.end()) {
4749 if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
4750 castVal = attr.getValue();
4751 }
4752
4753 auto toType = block.getArgument(2).getType();
4754 Value castValA = helper.buildTypeFn(castVal, toType, block.getArgument(0));
4755 Value castValB = helper.buildTypeFn(castVal, toType, block.getArgument(1));
4756 Value mulVal =
4757 helper.buildBinaryFn(BinaryFn::mul, castValA, castValB, emitError);
4758 if (!castValA || !castValB || !mulVal)
4759 return;
4760 Value addVal = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
4761 mulVal, emitError);
4762 if (!addVal)
4763 return;
4764 yields.push_back(addVal);
4765 helper.yieldOutputs(yields);
4766}
4767
4768ParseResult BatchMatmulOp::parse(OpAsmParser &parser, OperationState &result) {
4769 SmallVector<Attribute, 3> indexingMapsAttr;
4770 Attribute mapAttr;
4771 if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
4772 if (parser.parseEqual())
4773 return failure();
4774
4775 if (parser.parseLSquare())
4776 return failure();
4777
4778 do {
4779 if (parser.parseAttribute(mapAttr))
4780 return failure();
4781 if (!isa<AffineMapAttr>(mapAttr)) {
4782 return parser.emitError(parser.getCurrentLocation(),
4783 "expected affine map attribute");
4784 }
4785 indexingMapsAttr.push_back(mapAttr);
4786
4787 if (parser.parseOptionalComma())
4788 break;
4789 } while (true);
4790
4791 if (parser.parseRSquare())
4792 return failure();
4793 }
4794 // Initialize indexingMaps, if not supplied explicitly.
4795 if (indexingMapsAttr.empty()) {
4796 indexingMapsAttr = llvm::map_to_vector(
4797 BatchMatmulOp::getDefaultIndexingMaps(parser.getContext()),
4798 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4799 }
4800 result.addAttribute("indexing_maps",
4801 parser.getBuilder().getArrayAttr(indexingMapsAttr));
4802
4803 return ::parseNamedStructuredOp(parser, result,
4804 BatchMatmulOp::getNumRegionArgs(),
4805 BatchMatmulOp::getRegionBuilder());
4806}
4807
4808void BatchMatmulOp::print(OpAsmPrinter &p) {
4809 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
4810 BatchMatmulOp::getDefaultIndexingMaps(getContext()),
4811 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4812 if (!llvm::equal(getIndexingMaps(), indexingMaps))
4813 p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4814
4815 std::array<StringRef, 3> elidedAttrs = {
4816 "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
4817 ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
4818 elidedAttrs);
4819}
4820
// Op verifier: returns success immediately when the op uses the default
// indexing maps; otherwise visits operand indices 0..2 and fails as soon as
// the per-operand check fails.
4821/// Verify the user defined indexing maps.
4822LogicalResult BatchMatmulOp::verify() {
4823 // Verification of pure batch_matmul is handled by
4824 // verifyStructuredOpInterface().
4825 if (!hasUserDefinedMaps())
4826 return success();
4827
4828 for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
// NOTE(review): the condition guarding the `return failure()` below (source
// line 4829) is absent from this extract; presumably it checks the indexing
// map of operand `opIndex` - confirm against the original file.
4830 return failure();
4831 }
4832 return success();
4833}
4834
4835LogicalResult BatchMatmulOp::fold(FoldAdaptor,
4836 SmallVectorImpl<OpFoldResult> &) {
4837 return memref::foldMemRefCast(*this);
4838}
4839
4840void BatchMatmulOp::getEffects(
4841 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4842 &effects) {
4843 if (hasPureTensorSemantics())
4844 return;
4845 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
4846}
4847
4848Speculation::Speculatability BatchMatmulOp::getSpeculatability() {
4849 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
4850}
4851
4852//===----------------------------------------------------------------------===//
4853// ElementwiseOp
4854//===----------------------------------------------------------------------===//
4855//
4856namespace {
4857struct ArityGroupAndKind {
4858 // The enum class {Unary, Binary, Ternary, ..}
4859 ElementwiseArityGroup arityGroup;
4860
4861 // The kind (e.g. `exp` or `add`) belonging to the arity group.
4862 union Kind {
4863 UnaryFn unaryFn;
4864 BinaryFn binaryFn;
4865 TernaryFn ternaryFn;
4866 } kind;
4867};
4868
4869unsigned getArityGroupAsUInt(ElementwiseArityGroup arityGroup) {
4870 return static_cast<unsigned>(arityGroup);
4871}
4872} // namespace
4873
4874static ArityGroupAndKind getArityGroupAndKind(ElementwiseKind kind) {
4875 constexpr int lastUnary = static_cast<int>(ElementwiseCaseLimits::LastUnary);
4876 constexpr int lastBinary =
4877 static_cast<int>(ElementwiseCaseLimits::LastBinary);
4878 constexpr int lastTernary =
4879 static_cast<int>(ElementwiseCaseLimits::LastTernary);
4880
4881 int val = static_cast<int>(kind);
4882 ArityGroupAndKind result;
4883
4884 if (val < lastUnary) {
4885 result.arityGroup = ElementwiseArityGroup::Unary;
4886 result.kind.unaryFn = static_cast<UnaryFn>(val);
4887 return result;
4888 }
4889 if (val < lastBinary) {
4890 result.arityGroup = ElementwiseArityGroup::Binary;
4891 result.kind.binaryFn = static_cast<BinaryFn>(val - lastUnary);
4892 return result;
4893 }
4894 if (val >= lastTernary) {
4895 llvm_unreachable("unhandled ElementwiseFn");
4896 }
4897 result.arityGroup = ElementwiseArityGroup::Ternary;
4898 result.kind.ternaryFn = static_cast<TernaryFn>(val - lastBinary);
4899 return result;
4900}
4901
4902SmallVector<utils::IteratorType> ElementwiseOp::getIteratorTypesArray() {
4903 auto rank = getResultRank();
4904 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
4905}
4906
4908ElementwiseOp::getDefaultIndexingMaps(unsigned numMaps, unsigned numDims,
4909 MLIRContext *context) {
4910 auto map = AffineMap::getMultiDimIdentityMap(numDims, context);
4911 return SmallVector<AffineMap>(numMaps, map);
4912}
4913
4914ParseResult ElementwiseOp::parse(OpAsmParser &parser, OperationState &result) {
4915 // Expect e.g. `kind = #linalg.elemwise_kind<add>`
4916 Attribute attr;
4917 mlir::linalg::ElementwiseKind elemwiseKindVal;
4918 if (parser.parseKeyword("kind") || parser.parseEqual())
4919 return failure();
4920
4921 if (succeeded(parser.parseAttribute(attr))) {
4922 auto elemwiseKindAttr = dyn_cast<ElementwiseKindAttr>(attr);
4923 if (!elemwiseKindAttr)
4924 return parser.emitError(parser.getCurrentLocation(),
4925 "expected ElementwiseKind attribute");
4926 elemwiseKindVal = elemwiseKindAttr.getValue();
4927 } else {
4928 return parser.emitError(parser.getCurrentLocation(),
4929 "expected operation 'kind' attribute");
4930 }
4931 result.addAttribute(
4932 "kind", ElementwiseKindAttr::get(parser.getContext(), elemwiseKindVal));
4933
4934 // Parse optional `indexing_maps`
4935 SmallVector<Attribute, 3> indexingMapsAttr;
4936 Attribute mapAttr;
4937 if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
4938 if (parser.parseEqual())
4939 return failure();
4940 if (parser.parseLSquare())
4941 return failure();
4942 do {
4943 if (parser.parseAttribute(mapAttr))
4944 return failure();
4945 if (!isa<AffineMapAttr>(mapAttr))
4946 return parser.emitError(parser.getCurrentLocation(),
4947 "expected affine map attribute");
4948 indexingMapsAttr.push_back(mapAttr);
4949 if (parser.parseOptionalComma())
4950 break;
4951 } while (true);
4952 if (parser.parseRSquare())
4953 return failure();
4954 }
4955 // At this stage of parsing the only way to infer number of region
4956 // args is through op kind, as input output tensors are not parsed yet.
4957 auto arityGroupAndKind = getArityGroupAndKind(elemwiseKindVal);
4958 int numRegionArgs =
4959 getArityGroupAsUInt(arityGroupAndKind.arityGroup) + 1 /*output*/;
4960 if (parseNamedStructuredOp(parser, result, numRegionArgs,
4961 ElementwiseOp::getRegionBuilder())) {
4962 return parser.emitError(parser.getCurrentLocation(),
4963 "unable to parse elemwise op");
4964 }
4965
4966 // Initialize indexingMaps, if not supplied explicitly.
4967 if (indexingMapsAttr.empty()) {
4968 // We need to infer the numDims of the indexing maps from the output
4969 // type which is already parsed by now.
4970 auto resultType = result.operands[result.operands.size() - 1].getType();
4971 auto shapedType = llvm::dyn_cast<ShapedType>(resultType);
4972 if (!shapedType)
4973 return parser.emitError(parser.getCurrentLocation(),
4974 "return type needs to be shaped type");
4975 auto numDims = shapedType.getRank();
4976 indexingMapsAttr = llvm::map_to_vector(
4977 ElementwiseOp::getDefaultIndexingMaps(numRegionArgs, numDims,
4978 parser.getContext()),
4979 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4980 }
4981
4982 result.addAttribute("indexing_maps",
4983 parser.getBuilder().getArrayAttr(indexingMapsAttr));
4984 return success();
4985}
4986
4987void ElementwiseOp::print(OpAsmPrinter &p) {
4988 p << " kind=";
4989 p.printAttribute(getKindAttr());
4990 SmallVector<StringRef, 3> elidedAttrs = {"operandSegmentSizes", "kind",
4991 "indexing_maps"};
4992 unsigned arity =
4993 getArityGroupAsUInt(getArityGroupAndKind(getKind()).arityGroup);
4994 unsigned numDims = getResultRank();
4995
4996 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
4997 ElementwiseOp::getDefaultIndexingMaps(arity + 1 /*output*/, numDims,
4998 getContext()),
4999 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
5000
5001 if (!llvm::equal(getIndexingMaps(), indexingMaps))
5002 p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
5003
5004 printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
5005 elidedAttrs);
5006}
5007
5008/// Implements the block region builder for the ElementwiseOp. This is called by
5009/// 'fillStructuredOpRegion'.
5010void ElementwiseOp::regionBuilder(
5011 ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs,
5012 function_ref<InFlightDiagnostic()> emitError) {
5013 ElementwiseKind elemwiseKind;
5014 for (auto attr : attrs) {
5015 if (attr.getName() == b.getStringAttr("kind")) {
5016 auto kindAttr = dyn_cast<ElementwiseKindAttr>(attr.getValue());
5017 assert(kindAttr && "op kind attribute incorrectly set");
5018 elemwiseKind = kindAttr.getValue();
5019 break;
5020 }
5021 }
5022
5023 ArityGroupAndKind groupAndKind = getArityGroupAndKind(elemwiseKind);
5024 auto arityGroup = groupAndKind.arityGroup;
5025 auto kind = groupAndKind.kind;
5026 if (emitError && block.getNumArguments() !=
5027 getArityGroupAsUInt(arityGroup) + 1 /*output*/) {
5028 emitError() << "Elementwise regionBuilder expects "
5029 << (getArityGroupAsUInt(arityGroup) + 1) << " args, got "
5030 << block.getNumArguments();
5031 return;
5032 }
5033 assert(block.getNumArguments() ==
5034 getArityGroupAsUInt(arityGroup) + 1 /*output*/
5035 && "Elementwise regionBuilder number of block args mismatch");
5036
5037 RegionBuilderHelper helper(b, block);
5038 SmallVector<Value> yields;
5039 Value result;
5040
5041 if (arityGroup == ElementwiseArityGroup::Unary) {
5042 result = helper.buildUnaryFn(kind.unaryFn, block.getArgument(0));
5043
5044 } else if (arityGroup == ElementwiseArityGroup::Binary) {
5045 result = helper.buildBinaryFn(kind.binaryFn, block.getArgument(0),
5046 block.getArgument(1));
5047
5048 } else if (arityGroup == ElementwiseArityGroup::Ternary) {
5049 result = helper.buildTernaryFn(kind.ternaryFn, block.getArgument(0),
5050 block.getArgument(1), block.getArgument(2));
5051
5052 } else {
5053 assert(false && "found unhandled category in elemwise");
5054 }
5055
5056 yields.push_back(result);
5057 helper.yieldOutputs(yields);
5058}
5059
5060LogicalResult ElementwiseOp::fold(FoldAdaptor,
5061 SmallVectorImpl<OpFoldResult> &) {
5062 return memref::foldMemRefCast(*this);
5063}
5064
5065void ElementwiseOp::getEffects(
5066 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
5067 &effects) {
5068 if (hasPureTensorSemantics())
5069 return;
5070 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
5071}
5072
5073Speculation::Speculatability ElementwiseOp::getSpeculatability() {
5074 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
5075}
5076
5077//===----------------------------------------------------------------------===//
5078// PackOp/UnPackOp Common
5079//===----------------------------------------------------------------------===//
5080
// For a pack or unpack op, selects the packed type (dest of a pack, source of
// an unpack) and the unpacked type (the other one), and works on the leading
// `unpackedType.getRank()` extents of the packed shape. When `outer_dims_perm`
// is non-empty, the inverse of that permutation is involved (presumably applied
// to `result` — the call itself is not visible here).
// NOTE(review): listing lines 5083, 5090 and 5093 (the function declarator, the
// declaration of `result`, and the callee that receives the inverted
// permutation) are missing from this copy — restore them from the upstream
// file before building.
5081 template <typename OpTy, typename>
5082 SmallVector<int64_t>
5084 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
5085 ? packOrUnPack.getDestType()
5086 : packOrUnPack.getSourceType();
5087 ShapedType unpackedType = (std::is_same<OpTy, PackOp>::value)
5088 ? packOrUnPack.getSourceType()
5089 : packOrUnPack.getDestType();
5091 packedType.getShape().take_front(unpackedType.getRank()));
5092 if (!packOrUnPack.getOuterDimsPerm().empty()) {
5094 result, invertPermutationVector(packOrUnPack.getOuterDimsPerm()));
5095 }
5096 return result;
5097}
5102
5103// Given the (potentially) updated packed type, `newPackedTy`, generates an
5104// updated mixed-tile-sizes list. For each inner packed dimension that is static
5105// in `newPackedTy`, the tile is set to that static size (replacing SSA values
5106// or mismatched constants). Dynamic packed dimensions preserve the original
5107// tile. The folded tensor type is treated as authoritative for static extents.
5108// Note - packed-type-dim and mixed-tile-size should always match!
// NOTE(review): listing lines 5109-5110 (return type, function name and the
// leading parameters, which must declare `rewriter` and `newPackedTy`) are
// missing from this copy — restore them from the upstream file.
5111 ArrayRef<OpFoldResult> mixedTiles) {
5112 SmallVector<OpFoldResult> newMixedTileSizes;
// Pair each of the trailing `mixedTiles.size()` dims of the packed shape with
// the corresponding original tile.
5113 for (auto it : llvm::zip(cast<ShapedType>(newPackedTy)
5114 .getShape()
5115 .take_back(mixedTiles.size()),
5116 mixedTiles)) {
5117 int64_t dimSize = std::get<0>(it);
// Dynamic packed dim: keep the original tile untouched.
5118 if (dimSize == ShapedType::kDynamic) {
5119 newMixedTileSizes.push_back(std::get<1>(it));
5120 continue;
5121 }
// Static packed dim: the tile becomes an index attribute of that size.
5122 newMixedTileSizes.push_back(rewriter.getIndexAttr(dimSize));
5123 }
5124
5125 return newMixedTileSizes;
5126}
5127
// Shared implementation of result-shape reification for pack/unpack: produces
// a single result shape with one entry per destination dimension, each obtained
// from `createFoldedDimOp` on the op's destination operand. Always succeeds.
// NOTE(review): listing line 5130 (function name and the leading parameters,
// which must declare `op` and `builder`) is missing from this copy.
5128 template <typename OpTy>
5129static LogicalResult
5131 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
5132 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5133 "applies to only pack or unpack operations");
5134 int64_t destRank = op.getDestRank();
5135 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(destRank));
5136 for (auto dim : llvm::seq<int64_t>(0, destRank))
5137 reifiedReturnShapes[0][dim] =
5138 createFoldedDimOp(builder, op.getLoc(), op.getDest(), dim);
5139 return success();
5140}
5141
// Builds a map from each tiled dimension (entry of `inner_dims_pos`) to the
// tile size (mixed static/dynamic) at the same index of `getMixedTiles()`.
// NOTE(review): listing line 5143 (return type, function name and the `op`
// parameter) is missing from this copy.
5142 template <typename OpTy>
5144 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5145 "applies to only pack or unpack operations");
5146 DenseMap<int64_t, OpFoldResult> dimAndTileMapping;
5147 ArrayRef<int64_t> dimsToTile = op.getInnerDimsPos();
5148 SmallVector<OpFoldResult> tiles = op.getMixedTiles();
5149 assert(tiles.size() == dimsToTile.size() &&
5150 "tiles must match indices of dimension to block");
5151 // bind the dimension `i` with the tile factor.
5152 for (auto i : llvm::seq<int64_t>(0, dimsToTile.size()))
5153 dimAndTileMapping[dimsToTile[i]] = tiles[i];
5154 return dimAndTileMapping;
5155}
5156
// Rebuilds the mixed (attribute-or-value) list of inner tiles from the op's
// static tile array and its dynamic tile operands.
// NOTE(review): listing line 5158 (return type, function name and the `op`
// parameter) is missing from this copy.
5157 template <typename OpTy>
5159 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5160 "applies to only pack or unpack operations");
5161 Builder builder(op);
5162 SmallVector<OpFoldResult> mixedInnerTiles;
// Index of the next dynamic tile operand to consume.
5163 unsigned dynamicValIndex = 0;
5164 for (int64_t staticTile : op.getStaticInnerTiles()) {
// A static entry becomes an i64 integer attribute; any other entry takes the
// next SSA value from `getInnerTiles()`, in order.
5165 if (ShapedType::isStatic(staticTile))
5166 mixedInnerTiles.push_back(builder.getI64IntegerAttr(staticTile));
5167 else
5168 mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
5169 }
5170 return mixedInnerTiles;
5171}
5172
// Returns the static half of `dispatchIndexOpFoldResults` applied to the op's
// mixed tiles; the dynamic values collected alongside are discarded.
// NOTE(review): listing line 5174 (return type, function name and the `op`
// parameter) is missing from this copy.
5173 template <typename OpTy>
5175 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5176 "applies to only pack or unpack operations");
5177 SmallVector<Value> dynamicTiles;
5178 SmallVector<int64_t> staticTiles;
5179 dispatchIndexOpFoldResults(op.getMixedTiles(), dynamicTiles, staticTiles);
5180 return staticTiles;
5181}
5182
5183/// Returns true if `dimsPos` is invalid. It is invalid when:
5184/// a) It contains duplicates.
5185/// b) At least one dimension is out of bounds (a valid `dimPos` is >= 0 and
/// < rank).
5186/// c) The number of elements in `dimsPos` is greater than `rank`.
// NOTE(review): listing line 5187 (return type, function name and the `dimsPos`
// parameter) is missing from this copy.
5188 size_t rank) {
5189 size_t dimsPosSize = dimsPos.size();
5190 if (dimsPosSize > rank)
5191 return true;
// Duplicates collapse in the set, so a size mismatch means a repeated entry.
5192 DenseSet<int64_t> uniqued(llvm::from_range, dimsPos);
5193 if (dimsPosSize != uniqued.size())
5194 return true;
5195 return llvm::any_of(dimsPos, [rank](int64_t dimPos) {
5196 return dimPos < 0 || dimPos >= static_cast<int64_t>(rank);
5197 });
5198}
5199
5200template <typename OpTy>
5201static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
5202 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5203 "applies to only pack or unpack operations");
5204 Operation *op = packOrUnPack.getOperation();
5205
5206 // Return true if we have a zero-value tile.
5207 auto hasZeros = [&](ArrayRef<OpFoldResult> tiles) {
5208 return llvm::any_of(tiles, [](OpFoldResult tile) {
5209 return isa<Attribute>(tile) && isZeroInteger(tile);
5210 });
5211 };
5212
5213 // Verify that the source and destination are ranked types.
5214 if (!packOrUnPack.getSourceType().hasRank() ||
5215 !packOrUnPack.getDestType().hasRank())
5216 return op->emitError("expected both source and destination to have rank");
5217
5218 // Verify that the Operation does not have mixed tensor/buffer semantics.
5219 if (!packOrUnPack.hasPureBufferSemantics() &&
5220 !packOrUnPack.hasPureTensorSemantics())
5221 return op->emitError("mixing tensor and buffer semantics is not allowed");
5222 const unsigned numResults = packOrUnPack.getNumResults();
5223 if (packOrUnPack.hasPureTensorSemantics() && numResults != 1)
5224 return op->emitError("expected 1 result, got ") << numResults;
5225 if (packOrUnPack.hasPureBufferSemantics() && numResults != 0)
5226 return op->emitError("expected 0 results, got ") << numResults;
5227
5228 // Verify tiles. Do not allow zero tiles.
5229 SmallVector<OpFoldResult> mixedTiles = packOrUnPack.getMixedTiles();
5230 if (hasZeros(mixedTiles))
5231 return op->emitError("invalid zero tile factor");
5232
5233 // Verify inner_dims_pos and outer_dims_perm.
5234 ShapedType unpackedType = (std::is_same<OpTy, PackOp>::value)
5235 ? packOrUnPack.getSourceType()
5236 : packOrUnPack.getDestType();
5237 size_t unpackedRank = unpackedType.getRank();
5238 ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos();
5239 ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
5240 if (isInvalidPackingPosSpecification(innerDimsPos, unpackedRank))
5241 return op->emitError("invalid inner_dims_pos vector");
5242 if (isInvalidPackingPosSpecification(outerDimPerm, unpackedRank))
5243 return op->emitError("invalid outer_dims_perm vector");
5244 if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank)
5245 return op->emitError("outer_dims_perm must be a permutation or empty");
5246
5247 // Tiling factors must be less than or equal to the input rank for pack (or
5248 // output rank for unpack), and must match the number of `inner_dims_pos`.
5249 if (mixedTiles.size() > unpackedRank) {
5250 return op->emitError("tiling factors must be less than or equal to the "
5251 "input rank for pack or output rank for unpack");
5252 }
5253 if (mixedTiles.size() != innerDimsPos.size()) {
5254 return op->emitError(
5255 "tiling factors must equal the number of dimensions to tile");
5256 }
5257
5258 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
5259 ? packOrUnPack.getDestType()
5260 : packOrUnPack.getSourceType();
5261 size_t packedRank = packedType.getRank();
5262 // Require output rank to match input rank + number of blocking factors.
5263 size_t expectedPackedRank = unpackedRank + mixedTiles.size();
5264 if (expectedPackedRank != packedRank) {
5265 return op->emitError(
5266 "packed rank != (unpacked rank + num tiling factors), got ")
5267 << packedRank << " != " << expectedPackedRank;
5268 }
5269
5270 // Verify result shape is greater than the minimum expected
5271 // by the pack operation, and that the output shape
5272 // represents full tiles.
5273 SmallVector<int64_t> expectedPackedShape = PackOp::inferPackedShape(
5274 unpackedType.getShape(), packOrUnPack.getStaticTiles(),
5275 packOrUnPack.getInnerDimsPos(), packOrUnPack.getOuterDimsPerm());
5276 for (auto it : llvm::enumerate(llvm::zip(
5277 packedType.getShape().take_back(mixedTiles.size()), mixedTiles))) {
5278 int64_t dimSize = std::get<0>(it.value());
5279 if (Attribute attr =
5280 llvm::dyn_cast_if_present<Attribute>(std::get<1>(it.value()))) {
5281 IntegerAttr intAttr = dyn_cast_or_null<IntegerAttr>(attr);
5282 int64_t staticTileSize = intAttr.getValue().getSExtValue();
5283 if (dimSize != staticTileSize)
5284 return op->emitError(
5285 "mismatch in inner tile sizes specified and shaped of "
5286 "tiled dimension in the packed type at index ")
5287 << it.index() << ": got " << dimSize << " != " << staticTileSize;
5288 } else if (!ShapedType::isDynamic(dimSize)) {
5289 return op->emitError("mismatch in inner tile sizes specified at index ")
5290 << it.index() << ": got static shape " << dimSize
5291 << " but dynamic tile size";
5292 }
5293 }
5294 if (failed(
5295 verifyCompatibleShape(expectedPackedShape, packedType.getShape()))) {
5296 auto elementType = unpackedType.getElementType();
5297 Type expectedType, actualType;
5298 if (packOrUnPack.hasPureTensorSemantics()) {
5299 expectedType = RankedTensorType::get(expectedPackedShape, elementType);
5300 actualType = RankedTensorType::get(packedType.getShape(), elementType);
5301 } else {
5302 expectedType = MemRefType::get(expectedPackedShape, elementType);
5303 actualType = MemRefType::get(packedType.getShape(), elementType);
5304 }
5305 return op->emitError("expected ")
5306 << expectedType << " for the packed domain value, got "
5307 << actualType;
5308 }
5309 return success();
5310}
5311
5312namespace {
5313/// Subset of PackOp/UnPackOp fields used to compute the result of applying
5314/// various permutations to the op.
5315// TODO: Add linalg.transpose + pack/unpack folding patterns that just reuse
5316// these. These may or may not become true foldings / canonicalizations
5317// depending on how aggressive we want to be in automatically folding
5318// transposes.
5319struct PackOrUnPackTransposeResult {
// The op's `inner_dims_pos`, after the inner permutation (if any).
5320 SmallVector<int64_t> innerDimsPos;
// The op's mixed inner tiles, after the same inner permutation (if any).
5321 SmallVector<OpFoldResult> innerTiles;
// The op's `outer_dims_perm` (identity when the op has none), after the outer
// permutation (if any).
5322 SmallVector<int64_t> outerDimsPerm;
5323};
5324} // namespace
5325
// Computes the `inner_dims_pos`, inner tiles and `outer_dims_perm` that result
// from applying `innerPermutation` and/or `outerPermutation` to a pack or
// unpack op. At least one of the two permutations must be non-empty; an empty
// permutation leaves the corresponding fields as read from the op.
// NOTE(review): listing line 5328 (function name and the `packOrUnPackOp`
// parameter) is missing from this copy.
5326 template <typename OpTy>
5327static PackOrUnPackTransposeResult
5329 ArrayRef<int64_t> innerPermutation,
5330 ArrayRef<int64_t> outerPermutation) {
5331 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5332 "applies to only pack or unpack operations");
5333 assert((!innerPermutation.empty() || !outerPermutation.empty()) &&
5334 "some permutation must be non-empty");
5335 PackOrUnPackTransposeResult metadata;
5336 metadata.innerDimsPos =
5337 SmallVector<int64_t>(packOrUnPackOp.getInnerDimsPos());
5338 metadata.innerTiles =
5339 SmallVector<OpFoldResult>(packOrUnPackOp.getMixedTiles());
// The number of outer dims is the rank of the unpacked side: the source of a
// pack, the destination of an unpack.
5340 int64_t numOuterDims = std::is_same<OpTy, PackOp>::value
5341 ? packOrUnPackOp.getSourceRank()
5342 : packOrUnPackOp.getDestRank();
// An absent outer_dims_perm is materialized as the identity [0, numOuterDims).
5343 metadata.outerDimsPerm =
5344 packOrUnPackOp.getOuterDimsPerm().empty()
5345 ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
5346 : SmallVector<int64_t>(packOrUnPackOp.getOuterDimsPerm());
5347 if (!innerPermutation.empty()) {
5348 assert(innerPermutation.size() == metadata.innerDimsPos.size() &&
5349 isPermutationVector(innerPermutation) &&
5350 "invalid inner permutation");
// Positions and tiles are permuted together so they stay paired.
5351 applyPermutationToVector(metadata.innerDimsPos, innerPermutation);
5352 applyPermutationToVector(metadata.innerTiles, innerPermutation);
5353 }
5354 if (!outerPermutation.empty()) {
5355 assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&
5356 isPermutationVector(outerPermutation) &&
5357 "invalid outer permutation");
5358 applyPermutationToVector(metadata.outerDimsPerm, outerPermutation);
5359 }
5360 return metadata;
5361}
5362
5363//===----------------------------------------------------------------------===//
5364// PackOp
5365//===----------------------------------------------------------------------===//
5366
5367void PackOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
5368 if (!getResults().empty())
5369 setNameFn(getResult(), "pack");
5370}
5371
// Parses the custom assembly of the pack op, in this order:
//   source operand
//   [`padding_value` `(` operand `:` type `)`]
//   [`outer_dims_perm` `=` integer list]
//   `inner_dims_pos` `=` integer list
//   `inner_tiles` `=` dynamic index list
//   `into` dest operand, optional attribute dictionary
//   `:` source type `->` dest type
// Operands are resolved as source, dest, padding value, dynamic tiles, and
// `operandSegmentSizes` is set accordingly. A result of the destination type is
// added only when the source type is not a MemRefType.
// NOTE(review): listing lines 5374-5375 (presumably the declarations of
// `dynamicTiles` and `paddingValue`), 5397 and 5412 (the calls that open the
// two list-parsing lambdas) are missing from this copy — restore them from the
// upstream file before building.
5372ParseResult PackOp::parse(OpAsmParser &parser, OperationState &result) {
5373 OpAsmParser::UnresolvedOperand source, dest;
5376 SmallVector<Type> paddingValueType;
5377 SmallVector<int64_t> staticTiles;
5378 DenseI64ArrayAttr innerDimsPos, outerDimsPerm;
5379 Type sourceType, destType, resultType;
5380
5381 if (parser.parseOperand(source))
5382 return failure();
5383
5384 if (succeeded(parser.parseOptionalKeyword("padding_value"))) {
5385 if (parser.parseLParen() ||
5386 parser.parseOperandList(paddingValue, /*requiredOperandCount=*/1) ||
5387 parser.parseColon() || parser.parseTypeList(paddingValueType) ||
5388 parser.parseRParen())
5389 return failure();
5390 }
5391
5392 if (succeeded(parser.parseOptionalKeyword("outer_dims_perm"))) {
5393 if (parser.parseEqual())
5394 return failure();
5395
5396 SmallVector<int64_t> outerDimsPermVec;
5398 int64_t value;
5399 if (parser.parseInteger(value))
5400 return failure();
5401 outerDimsPermVec.push_back(value);
5402 return success();
5403 }))
5404 return failure();
5405 outerDimsPerm = parser.getBuilder().getDenseI64ArrayAttr(outerDimsPermVec);
5406 }
5407
5408 if (parser.parseKeyword("inner_dims_pos") || parser.parseEqual())
5409 return failure();
5410
5411 SmallVector<int64_t> innerDimsPosVec;
5413 int64_t value;
5414 if (parser.parseInteger(value))
5415 return failure();
5416 innerDimsPosVec.push_back(value);
5417 return success();
5418 }))
5419 return failure();
5420 innerDimsPos = parser.getBuilder().getDenseI64ArrayAttr(innerDimsPosVec);
5421
5422 if (parser.parseKeyword("inner_tiles") || parser.parseEqual())
5423 return failure();
5424
5425 DenseI64ArrayAttr staticTilesAttr;
5426 if (parseDynamicIndexList(parser, dynamicTiles, staticTilesAttr))
5427 return failure();
5428 for (auto val : staticTilesAttr.asArrayRef())
5429 staticTiles.push_back(val);
5430
5431 if (parser.parseKeyword("into") || parser.parseOperand(dest))
5432 return failure();
5433
5434 if (parser.parseOptionalAttrDict(result.attributes))
5435 return failure();
5436
5437 if (parser.parseColon() || parser.parseType(sourceType))
5438 return failure();
5439
5440 bool hasArrow = succeeded(parser.parseOptionalArrow());
5441 if (hasArrow) {
5442 if (parser.parseType(destType))
5443 return failure();
5444 }
5445
5446 bool isMemRef = llvm::isa<MemRefType>(sourceType);
// The arrow is parsed as optional above, but its absence is always an error.
5447 if (!hasArrow) {
5448 return parser.emitError(parser.getCurrentLocation(),
5449 "pack/unpack requires '->' and destination type");
5450 }
5451
5452 if (!isMemRef)
5453 resultType = destType;
5454
5455 if (parser.resolveOperand(source, sourceType, result.operands) ||
5456 parser.resolveOperand(dest, destType, result.operands))
5457 return failure();
5458
5459 if (!paddingValue.empty() &&
5460 parser.resolveOperands(paddingValue, paddingValueType[0],
5461 result.operands))
5462 return failure();
5463
// Dynamic tile sizes are always resolved with the index type.
5464 if (!dynamicTiles.empty() &&
5465 parser.resolveOperands(dynamicTiles, parser.getBuilder().getIndexType(),
5466 result.operands))
5467 return failure();
5468
5469 result.addAttribute("static_inner_tiles",
5470 parser.getBuilder().getDenseI64ArrayAttr(staticTiles));
5471 result.addAttribute("inner_dims_pos", innerDimsPos);
5472 if (outerDimsPerm)
5473 result.addAttribute("outer_dims_perm", outerDimsPerm);
5474
// Segment order: source, dest, optional padding value, dynamic tiles.
5475 SmallVector<int32_t> segmentSizes = {
5476 1, 1, static_cast<int32_t>(paddingValue.size()),
5477 static_cast<int32_t>(dynamicTiles.size())};
5478 result.addAttribute("operandSegmentSizes",
5479 parser.getBuilder().getDenseI32ArrayAttr(segmentSizes));
5480
5481 if (!isMemRef)
5482 result.addTypes(resultType);
5483
5484 return success();
5485}
5486
5487void PackOp::print(OpAsmPrinter &p) {
5488 p << " " << getSource();
5489
5490 if (getPaddingValue()) {
5491 p << " padding_value(" << getPaddingValue() << " : "
5492 << getPaddingValue().getType() << ")";
5493 }
5494
5495 if (!getOuterDimsPerm().empty()) {
5496 p << " outer_dims_perm = [";
5497 llvm::interleaveComma(getOuterDimsPerm(), p);
5498 p << "]";
5499 }
5500
5501 p << " inner_dims_pos = [";
5502 llvm::interleaveComma(getInnerDimsPos(), p);
5503 p << "]";
5504
5505 p << " inner_tiles = ";
5506 printDynamicIndexList(p, *this, getInnerTiles(), getStaticInnerTilesAttr());
5507
5508 p << " into " << getDest();
5509
5510 p.printOptionalAttrDict((*this)->getAttrs(),
5511 {"static_inner_tiles", "inner_dims_pos",
5512 "outer_dims_perm", "operandSegmentSizes"});
5513
5514 p << " : " << getSource().getType();
5515 p << " -> " << getDest().getType();
5516}
5517
5518void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
5519 Value dest, ArrayRef<int64_t> innerDimsPos,
5520 ArrayRef<OpFoldResult> innerTiles,
5521 std::optional<Value> paddingValue,
5522 ArrayRef<int64_t> outerDimsPerm) {
5523 assert(innerDimsPos.size() == innerTiles.size() &&
5524 "number of tile sizes specified must match the specified number of "
5525 "original dimensions to be tiled");
5526 SmallVector<int64_t> staticTileSizes;
5527 SmallVector<Value> dynamicTileSizes;
5528 dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
5529 build(builder, state, dest.getType(), source, dest,
5530 paddingValue ? *paddingValue : nullptr,
5531 outerDimsPerm.empty() ? nullptr
5532 : builder.getDenseI64ArrayAttr(outerDimsPerm),
5533 builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
5534 builder.getDenseI64ArrayAttr(staticTileSizes));
5535}
5536
5537LogicalResult
5538PackOp::reifyResultShapes(OpBuilder &builder,
5539 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
5540 return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
5541}
5542
5543DenseMap<int64_t, OpFoldResult> PackOp::getDimAndTileMapping() {
5544 return getDimAndTileMappingImpl(*this);
5545}
5546
5547SmallVector<OpFoldResult> PackOp::getMixedTiles() {
5548 return getMixedTilesImpl(*this);
5549}
5550
5551SmallVector<int64_t> PackOp::getStaticTiles() {
5552 return getStaticTilesImpl(*this);
5553}
5554
5555ArrayRef<int64_t> PackOp::getAllOuterDims() {
5556 ShapedType inputType = getSourceType();
5557 int64_t inputRank = inputType.getRank();
5558 return getDestType().getShape().take_front(inputRank);
5559}
5560
5561SmallVector<int64_t> PackOp::getTiledOuterDims() {
5562 auto innerDimsPos = getInnerDimsPos();
5563 SmallVector<int64_t> outerDims(getAllOuterDims());
5564 SmallVector<int64_t> res;
5565
5566 // Recover the original order of the outer dims.
5567 SmallVector<int64_t> outerDimPermInv(getOuterDimsPerm());
5568 invertPermutationVector(outerDimPermInv);
5569 if (!outerDimPermInv.empty())
5570 applyPermutationToVector(outerDims, outerDimPermInv);
5571
5572 // Collect the outer dims corresponding to the tilled inner dims.
5573 for (auto index : innerDimsPos)
5574 res.push_back(outerDims[index]);
5575
5576 return res;
5577}
5578
5579bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
5580 ArrayRef<int64_t> innerDimsPos,
5581 ArrayRef<int64_t> outputShape,
5582 ArrayRef<int64_t> outerDimsPerm,
5583 ArrayRef<OpFoldResult> innerTiles) {
5584 SmallVector<int64_t> outputTileSizes(
5585 outputShape.take_front(inputShape.size()));
5586 if (!outerDimsPerm.empty()) {
5587 assert(outerDimsPerm.size() == outputTileSizes.size() &&
5588 "expected output and outer_dims_perm to have same size");
5589 applyPermutationToVector(outputTileSizes,
5590 invertPermutationVector(outerDimsPerm));
5591 }
5592 for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
5593 if (ShapedType::isDynamic(inputShape[pos]))
5594 continue;
5595 std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
5596 if (!constantTile) {
5597 if (ShapedType::isStatic(outputTileSizes[pos]) &&
5598 (inputShape[pos] % outputTileSizes[pos] != 0))
5599 return true;
5600 } else {
5601 assert(*constantTile != 0 && "static tile size can't be zero");
5602 if (inputShape[pos] % (*constantTile) != 0) {
5603 return true;
5604 }
5605 }
5606 }
5607 return false;
5608}
5609
5610bool PackOp::requirePaddingValueStrict(ArrayRef<int64_t> inputShape,
5611 ArrayRef<int64_t> innerDimsPos,
5612 ArrayRef<int64_t> outputShape,
5613 ArrayRef<int64_t> outerDimsPerm,
5614 ArrayRef<OpFoldResult> innerTiles) {
5615 SmallVector<int64_t> outputTileSizes(
5616 outputShape.take_front(inputShape.size()));
5617 if (!outerDimsPerm.empty()) {
5618 assert(outerDimsPerm.size() == outputTileSizes.size() &&
5619 "expected output and outer_dims_perm to have same size");
5620 applyPermutationToVector(outputTileSizes,
5621 invertPermutationVector(outerDimsPerm));
5622 }
5623 for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
5624 if (ShapedType::isDynamic(inputShape[pos]) ||
5625 ShapedType::isDynamic(outputTileSizes[pos]))
5626 return true;
5627 std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
5628 if (!constantTile)
5629 return true;
5630 assert(*constantTile != 0 && "static tile size can't be zero");
5631 if (inputShape[pos] % (*constantTile) != 0)
5632 return true;
5633 }
5634 return false;
5635}
5636
// Op verifier: checks the padding value's type against the source element type
// and, when no padding value is given, rejects ops for which
// `requirePaddingValue` reports a partial tile.
// NOTE(review): listing line 5638 (the condition guarding the first
// `return failure();`) is missing from this copy.
5637LogicalResult PackOp::verify() {
5639 return failure();
5640
5641 // Verify padding value, and bail out if the tile does not divide the
5642 // dimension fully. In the case of dynamic tile factors or dimensions, having
5643 // a partial tile is undefined behavior.
5644 auto paddingValue = getPaddingValue();
5645 if (paddingValue &&
5646 paddingValue.getType() != getSourceType().getElementType()) {
5647 return emitOpError("expected padding_value has ")
5648 << getSourceType().getElementType()
5649 << " but got: " << paddingValue.getType();
5650 }
5651
5652 if (!paddingValue &&
5653 requirePaddingValue(getSourceType().getShape(), getInnerDimsPos(),
5654 getDestType().getShape(), getOuterDimsPerm(),
5655 getMixedTiles())) {
5656 return emitOpError(
5657 "invalid tile factor or output size provided. Only full tiles are "
5658 "supported when padding_value is not set");
5659 }
5660 return success();
5661}
5662
5663/// Converts OpFoldResults to int64_t shape entries, unconditionally mapping all
5664/// Value's to kDynamic, even if they are arith.constant values.
// NOTE(review): listing lines 5666-5667 (function name, the `ofrs` parameter
// and the declaration of `result`) are missing from this copy.
5665static SmallVector<int64_t>
5668 for (auto o : ofrs) {
5669 // Have to do this first, as getConstantIntValue special-cases constants.
5670 if (llvm::dyn_cast_if_present<Value>(o))
5671 result.push_back(ShapedType::kDynamic);
5672 else
5673 result.push_back(getConstantIntValue(o).value_or(ShapedType::kDynamic));
5674 }
5675 return result;
5676}
5677
5678SmallVector<int64_t> PackOp::inferPackedShape(ArrayRef<int64_t> inputShape,
5679 ArrayRef<int64_t> innerTileSizes,
5680 ArrayRef<int64_t> innerDimsPos,
5681 ArrayRef<int64_t> outerDimsPerm) {
5682 SmallVector<int64_t> resultShape = llvm::to_vector(inputShape);
5683 for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
5684 if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
5685 continue;
5686 if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) {
5687 resultShape[tiledDim.value()] = ShapedType::kDynamic;
5688 continue;
5689 }
5690 resultShape[tiledDim.value()] = llvm::divideCeilSigned(
5691 resultShape[tiledDim.value()], innerTileSizes[tiledDim.index()]);
5692 }
5693
5694 // Swap tile loops if outer_dims_perm is available.
5695 if (!outerDimsPerm.empty())
5696 applyPermutationToVector(resultShape, outerDimsPerm);
5697
5698 // Append the inner tile dimensions.
5699 resultShape.append(innerTileSizes.begin(), innerTileSizes.end());
5700 return resultShape;
5701}
5702
5703SmallVector<OpFoldResult> PackOp::getResultShape(
5704 OpBuilder &builder, Location loc, ArrayRef<OpFoldResult> sourceDims,
5705 ArrayRef<OpFoldResult> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
5706 ArrayRef<int64_t> outerDimsPerm) {
5707 SmallVector<OpFoldResult> resultDims = llvm::to_vector(sourceDims);
5708
5709 AffineExpr s0, s1;
5710 bindSymbols(builder.getContext(), s0, s1);
5711 AffineExpr ceilDivExpr = s0.ceilDiv(s1);
5712 for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
5713 resultDims[tiledDim.value()] = affine::makeComposedFoldedAffineApply(
5714 builder, loc, ceilDivExpr,
5715 {resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]});
5716 }
5717 if (!outerDimsPerm.empty())
5718 applyPermutationToVector(resultDims, outerDimsPerm);
5719 resultDims.append(innerTileSizes.begin(), innerTileSizes.end());
5720
5721 SmallVector<int64_t> resultTypeShape =
5722 inferPackedShape(asShapeWithAnyValueAsDynamic(sourceDims),
5723 asShapeWithAnyValueAsDynamic(innerTileSizes),
5724 innerDimsPos, outerDimsPerm);
5725
5726 // Fix-up `resultDims` to ensure that they are Value's if and only if the
5727 // result type shape says it's a dynamic dim. This is needed as callers may
5728 // use dispatchIndexOpFoldResults on the result, and rely on exact number of
5729 // dynamic dims returned by that.
5730 for (unsigned i = 0; i < resultDims.size(); ++i) {
5731 if (ShapedType::isStatic(resultTypeShape[i]))
5732 continue;
5733 resultDims[i] =
5734 getValueOrCreateConstantIndexOp(builder, loc, resultDims[i]);
5735 }
5736
5737 return resultDims;
5738}
5739
5740RankedTensorType PackOp::inferPackedTensorType(
5741 RankedTensorType sourceType, ArrayRef<int64_t> innerTileSizes,
5742 ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
5743 SmallVector<int64_t> resultShape = inferPackedShape(
5744 sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
5745 return RankedTensorType::get(resultShape, sourceType.getElementType());
5746}
5747
5748MemRefType PackOp::inferPackedMemRefType(MemRefType sourceType,
5749 ArrayRef<int64_t> innerTileSizes,
5750 ArrayRef<int64_t> innerDimsPos,
5751 ArrayRef<int64_t> outerDimsPerm) {
5752 SmallVector<int64_t> resultShape = inferPackedShape(
5753 sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
5754 return MemRefType::get(resultShape, sourceType.getElementType());
5755}
5756
5757Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
5758 ArrayRef<OpFoldResult> innerTileSizes,
5759 ArrayRef<int64_t> innerDimsPos,
5760 ArrayRef<int64_t> outerDimsPerm) {
5761 AffineExpr dim0, dim1;
5762 bindDims(b.getContext(), dim0, dim1);
5763 auto ceilDiv = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
5764 return affine::makeComposedFoldedAffineApply(b, loc, dim0.ceilDiv(dim1),
5765 {v1, v2});
5766 };
5767
5768 SmallVector<OpFoldResult> mixedSizes;
5769 for (auto [index, value] : llvm::enumerate(
5770 llvm::cast<RankedTensorType>(source.getType()).getShape())) {
5771 if (ShapedType::isDynamic(value))
5772 mixedSizes.push_back(
5773 tensor::DimOp::create(b, loc, source, index).getResult());
5774 else
5775 mixedSizes.push_back(b.getIndexAttr(value));
5776 }
5777 for (auto it : llvm::zip(innerDimsPos, innerTileSizes)) {
5778 int64_t dimPos = std::get<0>(it);
5779 OpFoldResult tileSize = std::get<1>(it);
5780 mixedSizes[dimPos] = ceilDiv(mixedSizes[dimPos], tileSize);
5781 }
5782 if (!outerDimsPerm.empty())
5783 applyPermutationToVector<OpFoldResult>(mixedSizes, outerDimsPerm);
5784
5785 mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end());
5786 auto elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
5787 return tensor::EmptyOp::create(b, loc, mixedSizes, elemType);
5788}
5789
5790PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc,
5791 ArrayRef<int64_t> innerPermutation,
5792 ArrayRef<int64_t> outerPermutation) {
5793 PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
5794 *this, innerPermutation, outerPermutation);
5795 Value transposedDest =
5796 createDestinationTensor(b, loc, getSource(), metadata.innerTiles,
5797 metadata.innerDimsPos, metadata.outerDimsPerm);
5798 return PackOp::create(b, loc, getSource(), transposedDest,
5799 metadata.innerDimsPos, metadata.innerTiles,
5800 getPaddingValue(), metadata.outerDimsPerm);
5801}
5802
// Shared memory-effect computation for pack/unpack. Ops with pure tensor
// semantics get no effects. Otherwise, among the memref-typed operands, the
// source operand gets a Read effect and the destination operand gets both a
// Read and a Write effect; other operands are ignored.
// NOTE(review): listing lines 5804-5805 (function name and the leading
// parameters, which must declare `op` and `effects`) and 5818, 5822, 5825 (the
// final argument of each `emplace_back`) are missing from this copy.
5803 template <typename OpTy>
5806 &effects) {
5807 // No memory effects for pure tensor semantics
5808 if (op.hasPureTensorSemantics())
5809 return;
5810
5811 for (OpOperand &opOperand : op.getOperation()->getOpOperands()) {
5812 if (!llvm::isa<MemRefType>(opOperand.get().getType()))
5813 continue;
5814
5815 if (&opOperand == &op.getSourceMutable()) {
5816 effects.emplace_back(MemoryEffects::Read::get(), &opOperand, /*stage=*/0,
5817 /*effectOnFullRegion=*/true,
5819 } else if (&opOperand == &op.getDestMutable()) {
5820 effects.emplace_back(MemoryEffects::Read::get(), &opOperand, /*stage=*/0,
5821 /*effectOnFullRegion=*/true,
5823 effects.emplace_back(MemoryEffects::Write::get(), &opOperand, /*stage=*/0,
5824 /*effectOnFullRegion=*/true,
5826 }
5827 }
5828}
5829
// Memory effects of the pack op: forwards to the implementation shared with
// unpack.
// NOTE(review): listing line 5831 (the type of the `effects` parameter) is
// missing from this copy.
5830void PackOp::getEffects(
5832 &effects) {
5833 getPackUnPackEffectsImpl(*this, effects);
5834}
5835
// Memory effects of the unpack op: forwards to the implementation shared with
// pack.
// NOTE(review): listing line 5837 (the type of the `effects` parameter) is
// missing from this copy.
5836void UnPackOp::getEffects(
5838 &effects) {
5839 getPackUnPackEffectsImpl(*this, effects);
5840}
5841
5842/// Returns true if the tiles and the tiled dims are constant.
// The check pairs the trailing `mixedTiles.size()` dims of the packed type
// (dest of a pack, source of an unpack) with the mixed tiles and fails as soon
// as a tile is not a constant or a paired dim is dynamic.
// NOTE(review): listing line 5844 (return type, function name and the `op`
// parameter) is missing from this copy.
5843 template <typename OpTy>
5845 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5846 "applies to only pack or unpack operations");
5847 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
5848 ? op.getDestType()
5849 : op.getSourceType();
5850 SmallVector<OpFoldResult> mixedTiles = op.getMixedTiles();
5851 for (auto [dimDest, tile] : llvm::zip(
5852 packedType.getShape().take_back(mixedTiles.size()), mixedTiles)) {
5853 std::optional<int64_t> constTileSize = getConstantIntValue(tile);
5854 if (!constTileSize || ShapedType::isDynamic(dimDest))
5855 return false;
5856 }
5857 return true;
5858}
5859
// Speculatability of the pack op. The decision depends on whether the op has
// pure tensor semantics, whether a padding value is present, and (per the
// comment below) whether tiles and tiled inner dimensions are constant.
// NOTE(review): listing lines 5862, 5864, 5869-5870 and 5872 (the return
// statements and the final condition) are missing from this copy, so the value
// returned on each path cannot be read from here — restore them from the
// upstream file.
5860Speculation::Speculatability PackOp::getSpeculatability() {
5861 if (!hasPureTensorSemantics())
5863 if (getPaddingValue())
5865
5866 // The verifier already rejects operations if we can statically prove that the
5867 // sizes of the tiles do not perfectly divide the dimension; thus, check only
5868 // to have constant tiles and tiled inner dimensions.
5871
5873}
5874
5875// Return true if `inner_dims_pos` and `outer_dims_perm` target the same
5876// dimensions for pack and unpack.
5877static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp) {
5878 if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
5879 return false;
5880 if (packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm())
5881 return true;
5882 // Outer dims permutation is optional.
5883 // To compare unbalanced pack-unpack pair, treat no permutation as equal to
5884 // identity permutation.
5885 return isIdentityPermutation(packOp.getOuterDimsPerm()) &&
5886 isIdentityPermutation(unPackOp.getOuterDimsPerm());
5887}
5888
5889// Return true if pack and unpack have the same tiles.
5890// Same SSA values or same integer constants.
5891static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp) {
5892 auto packTiles = packOp.getMixedTiles();
5893 auto unPackTiles = unPackOp.getMixedTiles();
5894 if (packTiles.size() != unPackTiles.size())
5895 return false;
5896 for (size_t i = 0, e = packTiles.size(); i < e; i++) {
5897 if (!isEqualConstantIntOrValue(packTiles[i], unPackTiles[i]))
5898 return false;
5899 }
5900 return true;
5901}
5902
5903/// Returns true if the pack op does not need a padding value.
5904static bool paddingIsNotNeeded(PackOp op) {
5905 auto srcType = op.getSourceType();
5906 auto innerDimsPos = op.getInnerDimsPos();
5907 auto innerTiles = op.getStaticInnerTiles();
5908 if (ShapedType::isDynamicShape(innerTiles))
5909 return false;
5910 for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
5911 if (srcType.isDynamicDim(pos) && tileSize != 1)
5912 return false;
5913 }
5914 return !PackOp::requirePaddingValue(
5915 srcType.getShape(), op.getInnerDimsPos(), op.getDestType().getShape(),
5916 op.getOuterDimsPerm(), op.getMixedTiles());
5917}
5918
5919/// Returns true if the `srcShape` or `destShape` is different from the one in
5920/// `packOp` and populates each with the inferred static shape.
5921static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
5922 SmallVectorImpl<int64_t> &destShape) {
5923 bool changeNeeded = false;
5924 srcShape.assign(packOp.getSourceType().getShape().begin(),
5925 packOp.getSourceType().getShape().end());
5926 destShape.assign(packOp.getDestType().getShape().begin(),
5927 packOp.getDestType().getShape().end());
5928 llvm::SmallSetVector<int64_t, 4> innerDims;
5929 innerDims.insert_range(packOp.getInnerDimsPos());
5930 SmallVector<int64_t> inverseOuterDimsPerm;
5931 if (!packOp.getOuterDimsPerm().empty())
5932 inverseOuterDimsPerm = invertPermutationVector(packOp.getOuterDimsPerm());
5933 int srcRank = packOp.getSourceRank();
5934 for (auto i : llvm::seq<int64_t>(0, srcRank)) {
5935 if (innerDims.contains(i))
5936 continue;
5937 int64_t srcPos = i;
5938 int64_t destPos = i;
5939 if (!inverseOuterDimsPerm.empty())
5940 destPos = inverseOuterDimsPerm[srcPos];
5941 if (ShapedType::isDynamic(srcShape[srcPos]) ==
5942 ShapedType::isDynamic(destShape[destPos])) {
5943 continue;
5944 }
5945 int64_t size = srcShape[srcPos];
5946 if (ShapedType::isDynamic(size))
5947 size = destShape[destPos];
5948 srcShape[srcPos] = size;
5949 destShape[destPos] = size;
5950 changeNeeded = true;
5951 }
5952 return changeNeeded;
5953}
5954
5955LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
5956 // TODO: Support Memref PackOp. Temporarily return failure.
5957 if (!packOp.hasPureTensorSemantics())
5958 return failure();
5959
5960 // Fold an pack(unpack(x)) to x.
5961 if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
5962 if (unPackOp.getSourceType() == packOp.getDestType() &&
5963 !packOp.getPaddingValue() &&
5964 hasSameInnerOuterAttribute(packOp, unPackOp) &&
5965 haveSameTiles(packOp, unPackOp)) {
5966 rewriter.replaceOp(packOp, unPackOp.getSource());
5967 return success();
5968 }
5969 }
5970
5971 // Fold optional PaddingValue operand away if padding is not needed.
5972 if (packOp.getPaddingValue() && paddingIsNotNeeded(packOp)) {
5973 rewriter.startOpModification(packOp);
5974 packOp.getPaddingValueMutable().clear();
5975 rewriter.finalizeOpModification(packOp);
5976 return success();
5977 }
5978
5979 // Insert tensor.cast ops if static shape inference is available..
5980 SmallVector<int64_t> srcShape, destShape;
5981 if (inferStaticShape(packOp, srcShape, destShape)) {
5982 Location loc = packOp.getLoc();
5983 Value source = packOp.getSource();
5984 if (srcShape != packOp.getSourceType().getShape()) {
5985 auto newSrcType = packOp.getSourceType().clone(srcShape);
5986 source =
5987 tensor::CastOp::create(rewriter, loc, newSrcType, packOp.getSource());
5988 }
5989 Value dest = packOp.getDest();
5990 ShapedType originalResultType = packOp.getDestType();
5991 bool needUpdateDestType = (destShape != originalResultType.getShape());
5992 if (needUpdateDestType) {
5993 auto newDestType = packOp.getDestType().clone(destShape);
5994 dest =
5995 tensor::CastOp::create(rewriter, loc, newDestType, packOp.getDest());
5996 }
5997 rewriter.modifyOpInPlace(packOp, [&] {
5998 packOp.getSourceMutable().assign(source);
5999 packOp.getDestMutable().assign(dest);
6000 packOp.getResult().setType(cast<RankedTensorType>(dest.getType()));
6001 });
6002 // Insert a cast if needed
6003 if (needUpdateDestType) {
6004 rewriter.setInsertionPointAfter(packOp);
6005 auto castOp = tensor::CastOp::create(rewriter, loc, originalResultType,
6006 packOp.getResult());
6007 rewriter.replaceAllUsesExcept(packOp.getResult(), castOp, castOp);
6008 }
6009 return success();
6010 }
6011
6012 return failure();
6013}
6014
6015template <typename PackOrUnpackOp>
6016static bool isLikePadUnPad(PackOrUnpackOp packOp, ShapedType packedTensorType) {
6017 static_assert(std::is_same<PackOrUnpackOp, PackOp>::value ||
6018 std::is_same<PackOrUnpackOp, UnPackOp>::value,
6019 "Function meant for pack/unpack");
6020 // This is a pad if packing only adds ones and we don't transpose dimensions.
6021
6022 // Check that we are not transposing any dimensions.
6023 ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
6024 int64_t numPackedDims = innerDimsPos.size();
6025 auto orderedDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, numPackedDims));
6026 if (orderedDims != innerDimsPos) {
6027 // Dimensions don't happen in order.
6028 return false;
6029 }
6030
6031 ArrayRef<int64_t> packedShape = packedTensorType.getShape();
6032 int64_t packedRank = packedTensorType.getRank();
6033 // At this point we know that we are taking numPackedDims outer
6034 // dimensions and pushing them all the way as the inner most dimensions.
6035 // What's left on the outer most dimensions is, in this order:
6036 // - the factor of the packed dimensions, then
6037 // - the untouched dimensions
6038 // This shifting inward of dimensions is a no-op (as opposed to a transpose)
6039 // if all the dimensions that bubble outerward are ones.
6040 // Therefore check that all the dimensions but the numPackedDims inner most
6041 // ones are ones.
6042 return llvm::all_of(
6043 llvm::seq<int64_t>(0, packedRank - numPackedDims),
6044 [&packedShape](int64_t i) { return packedShape[i] == 1; });
6045}
6046
6047bool PackOp::isLikePad() {
6048 auto packedTensorType =
6049 llvm::cast<ShapedType>((*this)->getResultTypes().front());
6050 return isLikePadUnPad(*this, packedTensorType);
6051}
6052
/// Folder for `PackOp`. Ops without pure tensor semantics are never folded.
/// Otherwise the fold succeeds when `reshapeConstantSource` yields a result
/// for the (possibly null) constant `DenseElementsAttr` source, the dest type
/// and the optional constant padding value; that result is appended to
/// `results`.
// NOTE(review): the line declaring the `results` parameter appears to be
// missing from this listing.
6053::mlir::LogicalResult
6054PackOp::fold(FoldAdaptor adaptor,
6056 if (!hasPureTensorSemantics())
6057 return failure();
     // Only forward a padding attribute when the adaptor provides one.
6058 std::optional<Attribute> paddingValue;
6059 if (auto pad = adaptor.getPaddingValue())
6060 paddingValue = pad;
6061 if (OpFoldResult reshapedSource = reshapeConstantSource(
6062 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
6063 cast<TensorType>(getDestType()), paddingValue)) {
6064 results.push_back(reshapedSource);
6065 return success();
6066 }
6067 return failure();
6068}
6069
6070/// Folds a tensor.cast op into a consuming PackOp op if the
6071/// `tensor.cast` has source that is more static than the consuming op.
6072///
6073/// Example:
6074/// ```mlir
6075/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
6076/// %2 = tensor.pack %1 ... : tensor<?x?xf32> ...
6077/// ```
6078///
6079/// folds into:
6080///
6081/// ```mlir
6082/// %2 = tensor.pack %0 ... : tensor<8x16xf32> ...
6083/// ```
// NOTE(review): several source lines appear to be missing from this listing:
// the struct header, the match condition guarding the second
// `return failure()`, the initializer of `newOperands`, and the declaration of
// `replacement`. Confirm against the original file before editing.
6086
6087 LogicalResult matchAndRewrite(PackOp op,
6088 PatternRewriter &rewriter) const override {
6089 // TODO: Support Memref PackOp. Temporarily return failure.
6090 if (!op.hasPureTensorSemantics())
6091 return failure();
6092
6094 return failure();
6095
6096 SmallVector<Type> newResultTypes(op->getResultTypes());
6097 SmallVector<Value> newOperands =
6099
6100 // Get the updated mixed-tile-sizes attribute.
6101 SmallVector<OpFoldResult> newMixedTileSizes =
6102 getNewMixedTileSizes(rewriter, newResultTypes[0], op.getMixedTiles());
     // Bail out rather than build a pack with a zero tile size.
6103 if (llvm::any_of(newMixedTileSizes, isZeroInteger))
6104 return failure();
6105
6106 // Clone op.
6107 // TODO: Strictly speaking, discardable attributes should be _discarded_ at
6108 // this point. However, in practice, we use them for things that we'd like
6109 // to preserve. Implement a better abstraction.
6110 PackOp newOp =
6111 PackOp::create(rewriter, op.getLoc(), newOperands[0], newOperands[1],
6112 op.getInnerDimsPos(), newMixedTileSizes,
6113 op.getPaddingValue(), op.getOuterDimsPerm());
6114 newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
6115
6116 // Replace op.
6117 Value oldResult = op.getResult();
6118 Value newResult = newOp.getResult();
     // If the new op's result type differs, cast back to the old result type
     // so existing users remain valid.
6120 (newResult.getType() != oldResult.getType())
6121 ? tensor::CastOp::create(rewriter, op->getLoc(),
6122 oldResult.getType(), newResult)
6123 : newResult;
6124
6125 rewriter.replaceOp(op, {replacement});
6126
6127 return success();
6128 }
6129};
6130
6131//===----------------------------------------------------------------------===//
6132// UnPackOp
6133//===----------------------------------------------------------------------===//
6134
6135void UnPackOp::getAsmResultNames(
6136 function_ref<void(Value, StringRef)> setNameFn) {
6137 if (!getResults().empty())
6138 setNameFn(getResult(), "unpack");
6139}
6140
6141// Custom parser for UnPackOp that handles the memref/tensor case distinction
//
// Accepted form, in parse order:
//   <source> [outer_dims_perm = <ints>] inner_dims_pos = <ints>
//   inner_tiles = <dynamic index list> into <dest> [attr-dict]
//   : <source type> -> <dest type>
// A result type (equal to the dest type) is added only when the source type is
// not a memref.
// NOTE(review): the declaration of `dynamicTiles` and the two calls that open
// the integer-list parsing lambdas appear to be missing from this listing.
6142ParseResult UnPackOp::parse(OpAsmParser &parser, OperationState &result) {
6143 OpAsmParser::UnresolvedOperand source, dest;
6145 SmallVector<int64_t> staticTiles;
6146 DenseI64ArrayAttr innerDimsPos, outerDimsPerm;
6147 Type sourceType, destType, resultType;
6148
6149 if (parser.parseOperand(source))
6150 return failure();
6151
     // `outer_dims_perm` is optional; the attribute stays null when absent.
6152 if (succeeded(parser.parseOptionalKeyword("outer_dims_perm"))) {
6153 if (parser.parseEqual())
6154 return failure();
6155
6156 SmallVector<int64_t> outerDimsPermVec;
6158 int64_t value;
6159 if (parser.parseInteger(value))
6160 return failure();
6161 outerDimsPermVec.push_back(value);
6162 return success();
6163 }))
6164 return failure();
6165 outerDimsPerm = parser.getBuilder().getDenseI64ArrayAttr(outerDimsPermVec);
6166 }
6167
     // `inner_dims_pos` is mandatory.
6168 if (parser.parseKeyword("inner_dims_pos") || parser.parseEqual())
6169 return failure();
6170
6171 SmallVector<int64_t> innerDimsPosVec;
6173 int64_t value;
6174 if (parser.parseInteger(value))
6175 return failure();
6176 innerDimsPosVec.push_back(value);
6177 return success();
6178 }))
6179 return failure();
6180 innerDimsPos = parser.getBuilder().getDenseI64ArrayAttr(innerDimsPosVec);
6181
     // `inner_tiles` is mandatory and may mix constants and SSA values.
6182 if (parser.parseKeyword("inner_tiles") || parser.parseEqual())
6183 return failure();
6184
6185 DenseI64ArrayAttr staticTilesAttr;
6186 if (parseDynamicIndexList(parser, dynamicTiles, staticTilesAttr))
6187 return failure();
6188 for (auto val : staticTilesAttr.asArrayRef())
6189 staticTiles.push_back(val);
6190
6191 if (parser.parseKeyword("into") || parser.parseOperand(dest))
6192 return failure();
6193
6194 if (parser.parseOptionalAttrDict(result.attributes))
6195 return failure();
6196
6197 if (parser.parseColon() || parser.parseType(sourceType))
6198 return failure();
6199
     // The arrow is parsed as optional so a dedicated error can be emitted
     // below when it is missing.
6200 bool hasArrow = succeeded(parser.parseOptionalArrow());
6201 if (hasArrow) {
6202 if (parser.parseType(destType))
6203 return failure();
6204 }
6205
6206 bool isMemRef = llvm::isa<MemRefType>(sourceType);
6207 if (!hasArrow) {
6208 return parser.emitError(parser.getCurrentLocation(),
6209 "pack/unpack requires '->' and destination type");
6210 }
6211
     // Only the non-memref form produces a result; its type is the dest type.
6212 if (!isMemRef)
6213 resultType = destType;
6214
     // Operands are appended in the order: source, dest, dynamic tiles.
6215 if (parser.resolveOperand(source, sourceType, result.operands) ||
6216 parser.resolveOperand(dest, destType, result.operands))
6217 return failure();
6218
6219 if (!dynamicTiles.empty() &&
6220 parser.resolveOperands(dynamicTiles, parser.getBuilder().getIndexType(),
6221 result.operands))
6222 return failure();
6223
6224 result.addAttribute("static_inner_tiles",
6225 parser.getBuilder().getDenseI64ArrayAttr(staticTiles));
6226 result.addAttribute("inner_dims_pos", innerDimsPos);
6227 if (outerDimsPerm)
6228 result.addAttribute("outer_dims_perm", outerDimsPerm);
6229
     // NOTE(review): four segment sizes are recorded with the third fixed at
     // 0, while only source, dest and tile operands are resolved above --
     // confirm this matches the op's declared operand groups.
6230 SmallVector<int32_t> segmentSizes = {
6231 1, 1, 0, static_cast<int32_t>(dynamicTiles.size())};
6232 result.addAttribute("operandSegmentSizes",
6233 parser.getBuilder().getDenseI32ArrayAttr(segmentSizes));
6234
6235 if (!isMemRef)
6236 result.addTypes(resultType);
6237
6238 return success();
6239}
6240
6241void UnPackOp::print(OpAsmPrinter &p) {
6242 p << " " << getSource();
6243
6244 if (!getOuterDimsPerm().empty()) {
6245 p << " outer_dims_perm = [";
6246 llvm::interleaveComma(getOuterDimsPerm(), p);
6247 p << "]";
6248 }
6249
6250 p << " inner_dims_pos = [";
6251 llvm::interleaveComma(getInnerDimsPos(), p);
6252 p << "]";
6253
6254 p << " inner_tiles = ";
6255 printDynamicIndexList(p, *this, getInnerTiles(), getStaticInnerTilesAttr());
6256
6257 p << " into " << getDest();
6258
6259 p.printOptionalAttrDict((*this)->getAttrs(),
6260 {"static_inner_tiles", "inner_dims_pos",
6261 "outer_dims_perm", "operandSegmentSizes"});
6262
6263 p << " : " << getSource().getType();
6264 p << " -> " << getDest().getType();
6265}
6266
6267LogicalResult
6268UnPackOp::reifyResultShapes(OpBuilder &builder,
6269 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
6270 return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
6271}
6272
6273DenseMap<int64_t, OpFoldResult> UnPackOp::getDimAndTileMapping() {
6274 return getDimAndTileMappingImpl(*this);
6275}
6276
6277SmallVector<OpFoldResult> UnPackOp::getMixedTiles() {
6278 return getMixedTilesImpl(*this);
6279}
6280
6281SmallVector<int64_t> UnPackOp::getStaticTiles() {
6282 return getStaticTilesImpl(*this);
6283}
6284
6285ArrayRef<int64_t> UnPackOp::getAllOuterDims() {
6286 ShapedType destType = getDestType();
6287 int64_t destRank = destType.getRank();
6288 return getSourceType().getShape().take_front(destRank);
6289}
6290
6291SmallVector<int64_t> UnPackOp::getTiledOuterDims() {
6292 auto innerDimsPos = getInnerDimsPos();
6293 SmallVector<int64_t> outerDims(getAllOuterDims());
6294 SmallVector<int64_t> res;
6295
6296 // Recover the original order of the outer dims.
6297 SmallVector<int64_t> outerDimPermInv(getOuterDimsPerm());
6298 invertPermutationVector(outerDimPermInv);
6299 if (!outerDimPermInv.empty())
6300 applyPermutationToVector(outerDims, outerDimPermInv);
6301
6302 // Collect the outer dims corresponding to the tilled inner dims.
6303 for (auto index : innerDimsPos)
6304 res.push_back(outerDims[index]);
6305
6306 return res;
6307}
6308
6309LogicalResult UnPackOp::verify() {
6310 return commonVerifierPackAndUnPackOp(*this);
6311}
6312
/// Reports the speculatability of this unpack op; the visible code branches on
/// whether the op has pure tensor semantics and refers to
/// `PackOp::getSpeculatability` for the rationale.
// NOTE(review): the statements of this function other than the first
// condition appear to be missing from this listing -- confirm the returned
// values against the original file.
6313Speculation::Speculatability UnPackOp::getSpeculatability() {
6314 if (!hasPureTensorSemantics())
6316 // See PackOp::getSpeculatability.
6319
6321}
6322
6323void UnPackOp::build(OpBuilder &builder, OperationState &state, Value source,
6324 Value dest, ArrayRef<int64_t> innerDimsPos,
6325 ArrayRef<OpFoldResult> innerTiles,
6326 ArrayRef<int64_t> outerDimsPerm) {
6327 assert(innerDimsPos.size() == innerTiles.size() &&
6328 "number of tile sizes specified must match the specified number of "
6329 "original dimensions to be tiled");
6330 SmallVector<int64_t> staticTileSizes;
6331 SmallVector<Value> dynamicTileSizes;
6332 dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
6333 build(builder, state, dest.getType(), source, dest,
6334 outerDimsPerm.empty() ? nullptr
6335 : builder.getDenseI64ArrayAttr(outerDimsPerm),
6336 builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
6337 builder.getDenseI64ArrayAttr(staticTileSizes));
6338}
6339
/// Creates a `tensor.empty` with the element type of `source`, sized for the
/// result of unpacking `source`.
///
/// The sizes start as the leading `rank - innerTileSizes.size()` dimensions of
/// `source` (index attributes for static dims, `tensor.dim` results for
/// dynamic ones); each dimension listed in `innerDimsPos` is then multiplied
/// by its tile size through a composed, folded affine apply.
// NOTE(review): the line that begins the call consuming
// `mixedSizes, invertPermutationVector(outerDimsPerm)` appears to be missing
// from this listing -- confirm how the outer permutation is applied.
6340Value UnPackOp::createDestinationTensor(OpBuilder &b, Location loc,
6341 Value source,
6342 ArrayRef<OpFoldResult> innerTileSizes,
6343 ArrayRef<int64_t> innerDimsPos,
6344 ArrayRef<int64_t> outerDimsPerm) {
6345 AffineExpr sym0, sym1;
6346 bindSymbols(b.getContext(), sym0, sym1);
     // Multiplies two (possibly static) index values, folding when possible.
6347 auto dimMul = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
6348 return affine::makeComposedFoldedAffineApply(b, loc, sym0 * sym1, {v1, v2});
6349 };
6350
6351 SmallVector<OpFoldResult> mixedSizes;
6352 auto srcType = llvm::cast<RankedTensorType>(source.getType());
     // Collect the sizes of the source's outer dims (everything but the
     // trailing tile dims).
6353 for (auto i :
6354 llvm::seq<unsigned>(0, srcType.getRank() - innerTileSizes.size())) {
6355 if (srcType.isDynamicDim(i))
6356 mixedSizes.push_back(
6357 tensor::DimOp::create(b, loc, source, i).getResult());
6358 else
6359 mixedSizes.push_back(b.getIndexAttr(srcType.getDimSize(i)));
6360 }
6361 if (!outerDimsPerm.empty()) {
6363 mixedSizes, invertPermutationVector(outerDimsPerm));
6364 }
6365
     // Scale every tiled dim by its tile size.
6366 for (auto [dimPos, tileSize] : llvm::zip_equal(innerDimsPos, innerTileSizes))
6367 mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize);
6368
6369 auto elemType = srcType.getElementType();
6370 return tensor::EmptyOp::create(b, loc, mixedSizes, elemType);
6371}
6372
6373UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
6374 Value transposedSource,
6375 ArrayRef<int64_t> innerPermutation,
6376 ArrayRef<int64_t> outerPermutation) {
6377 PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
6378 *this, innerPermutation, outerPermutation);
6379 return UnPackOp::create(b, loc, transposedSource, getDest(),
6380 metadata.innerDimsPos, metadata.innerTiles,
6381 metadata.outerDimsPerm);
6382}
6383
6384/// Returns true if the `srcShape` or `destShape` is different from the one in
6385/// `op` and populates each with the inferred static shape.
6386static bool inferStaticShape(UnPackOp op, SmallVectorImpl<int64_t> &srcShape,
6387 SmallVectorImpl<int64_t> &destShape) {
6388 bool changeNeeded = false;
6389 srcShape.assign(op.getSourceType().getShape().begin(),
6390 op.getSourceType().getShape().end());
6391 destShape.assign(op.getDestType().getShape().begin(),
6392 op.getDestType().getShape().end());
6393 llvm::SmallSetVector<int64_t, 4> innerDims;
6394 innerDims.insert_range(op.getInnerDimsPos());
6395 SmallVector<int64_t> inverseOuterDimsPerm;
6396 if (!op.getOuterDimsPerm().empty())
6397 inverseOuterDimsPerm = invertPermutationVector(op.getOuterDimsPerm());
6398 int destRank = op.getDestRank();
6399 for (auto i : llvm::seq<int64_t>(0, destRank)) {
6400 if (innerDims.contains(i))
6401 continue;
6402 int64_t srcPos = i;
6403 int64_t destPos = i;
6404 if (!inverseOuterDimsPerm.empty())
6405 srcPos = inverseOuterDimsPerm[destPos];
6406 if (ShapedType::isDynamic(srcShape[srcPos]) ==
6407 ShapedType::isDynamic(destShape[destPos])) {
6408 continue;
6409 }
6410 int64_t size = srcShape[srcPos];
6411 if (ShapedType::isDynamic(size))
6412 size = destShape[destPos];
6413 srcShape[srcPos] = size;
6414 destShape[destPos] = size;
6415 changeNeeded = true;
6416 }
6417 return changeNeeded;
6418}
6419
6420LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
6421 PatternRewriter &rewriter) {
6422 // TODO: Support Memref UnPackOp. Temporarily return failure.
6423 if (!unPackOp.hasPureTensorSemantics())
6424 return failure();
6425
6426 /// unpack(pack(x)) -> x
6427 if (PackOp packOp = unPackOp.getSource().getDefiningOp<PackOp>()) {
6428 if (packOp.getSourceType() != unPackOp.getDestType())
6429 return failure();
6430 if (packOp.getPaddingValue() ||
6431 !hasSameInnerOuterAttribute(packOp, unPackOp) ||
6432 !haveSameTiles(packOp, unPackOp))
6433 return failure();
6434 rewriter.replaceOp(unPackOp, packOp.getSource());
6435 return success();
6436 }
6437 /// unpack(destinationStyleOp(x)) -> unpack(x)
6438 if (auto dstStyleOp =
6439 unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
6440 auto destValue = cast<OpResult>(unPackOp.getDest());
6441 Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
6442 rewriter.modifyOpInPlace(unPackOp,
6443 [&]() { unPackOp.setDpsInitOperand(0, newDest); });
6444 return success();
6445 }
6446 /// extract_slice(unpack(x into y)) -> unpack(x into extract_slice(y))
6447 if (unPackOp->hasOneUse()) {
6448 auto extractSliceUser =
6449 dyn_cast<tensor::ExtractSliceOp>(*unPackOp->getUsers().begin());
6450 if (extractSliceUser && unPackOp.canFoldSliceOp(extractSliceUser)) {
6451 OpBuilder::InsertionGuard g(rewriter);
6452 rewriter.setInsertionPoint(unPackOp);
6453 auto newDest = tensor::ExtractSliceOp::create(
6454 rewriter, unPackOp->getLoc(), unPackOp.getDest(),
6455 extractSliceUser.getMixedOffsets(), extractSliceUser.getMixedSizes(),
6456 extractSliceUser.getMixedStrides());
6457 rewriter.modifyOpInPlace(unPackOp, [&]() {
6458 unPackOp.setDpsInitOperand(0, newDest);
6459 unPackOp.getResult().setType(newDest.getType());
6460 });
6461 rewriter.replaceOp(extractSliceUser, unPackOp);
6462 return success();
6463 }
6464 }
6465
6466 // Insert tensor.cast ops if static shape inference is available..
6467 SmallVector<int64_t> srcShape, destShape;
6468 if (inferStaticShape(unPackOp, srcShape, destShape)) {
6469 Location loc = unPackOp.getLoc();
6470 Value source = unPackOp.getSource();
6471 if (srcShape != unPackOp.getSourceType().getShape()) {
6472 auto newSrcType = unPackOp.getSourceType().clone(srcShape);
6473 source = tensor::CastOp::create(rewriter, loc, newSrcType,
6474 unPackOp.getSource());
6475 }
6476 Value dest = unPackOp.getDest();
6477 if (destShape != unPackOp.getDestType().getShape()) {
6478 auto newDestType = unPackOp.getDestType().clone(destShape);
6479 dest = tensor::CastOp::create(rewriter, loc, newDestType,
6480 unPackOp.getDest());
6481 }
6482 UnPackOp newOp = UnPackOp::create(
6483 rewriter, loc, source, dest, unPackOp.getInnerDimsPos(),
6484 unPackOp.getMixedTiles(), unPackOp.getOuterDimsPerm());
6485 rewriter.replaceOpWithNewOp<tensor::CastOp>(
6486 unPackOp, unPackOp.getResult().getType(), newOp.getResult());
6487 return success();
6488 }
6489
6490 return failure();
6491}
6492
/// Returns true if `sliceOp`, applied to this op's result, can be folded into
/// the unpack.
///
/// Requirements checked by the visible code: the slice is not rank-reducing,
/// has all-zero offsets and all-one strides; every tiled dimension has a
/// static slice size, static outer size and static tile size, and the slice
/// removes less than one full tile from it; every non-tiled dimension is
/// static and left untouched by the slice.
// NOTE(review): the initializer of `outerShapeWithoutTranspose` appears to be
// missing from this listing.
6493bool UnPackOp::canFoldSliceOp(tensor::ExtractSliceOp sliceOp) {
6494 // Rank-reduced folding is not supported.
6495 if (sliceOp.getResultType().getRank() != this->getDestType().getRank())
6496 return false;
6497 if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
6498 !areAllConstantIntValue(sliceOp.getMixedStrides(), 1))
6499 return false;
6500 RankedTensorType unpackedTypeAfterFold = sliceOp.getResultType();
6501 SmallVector<int64_t> outerShapeWithoutTranspose =
6503 SmallVector<bool> areOuterDimsTiled(outerShapeWithoutTranspose.size(), false);
6504 for (auto [pos, tileSize] :
6505 llvm::zip_equal(this->getInnerDimsPos(), this->getStaticInnerTiles())) {
6506 areOuterDimsTiled[pos] = true;
6507 if (unpackedTypeAfterFold.isDynamicDim(pos))
6508 return false;
6509 if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
6510 return false;
6511 if (ShapedType::isDynamic(tileSize))
6512 return false;
     // Elements dropped by the slice along this dim; a full tile or more
     // cannot be absorbed by the unpack.
6513 int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
6514 unpackedTypeAfterFold.getDimSize(pos);
6515 if (paddingSize >= tileSize)
6516 return false;
6517 }
6518 // extract_slice must not affect dimensions that are not being unpacked
6519 for (int64_t pos = 0, e = outerShapeWithoutTranspose.size(); pos < e; ++pos) {
6520 if (areOuterDimsTiled[pos])
6521 continue;
6522 int64_t dim = outerShapeWithoutTranspose[pos];
6523 if (ShapedType::isDynamic(dim))
6524 return false;
6525 if (dim != unpackedTypeAfterFold.getDimSize(pos))
6526 return false;
6527 }
6528 return true;
6529}
6530
6531bool UnPackOp::isLikeUnPad() {
6532 ShapedType packedTensorType = getSourceType();
6533 return isLikePadUnPad(*this, packedTensorType);
6534}
6535
6536::mlir::LogicalResult
6537UnPackOp::fold(FoldAdaptor adaptor,
6538 ::llvm::SmallVectorImpl<OpFoldResult> &results) {
6539 // TODO: Support Memref UnPackOp. Temporarily return failure.
6540 if (!hasPureTensorSemantics())
6541 return failure();
6542
6543 if (OpFoldResult reshapedSource = reshapeConstantSource(
6544 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
6545 cast<TensorType>(getResult().getType()))) {
6546 results.push_back(reshapedSource);
6547 return success();
6548 }
6549 return failure();
6550}
6551
6552/// Folds a tensor.cast op into a consuming UnPackOp op if the
6553/// `tensor.cast` has source that is more static than the consuming op.
6554///
6555/// Example:
6556/// ```mlir
6557/// %1 = tensor.cast %0 : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
6558/// %2 = tensor.unpack %1 ... : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
6559/// ```
6560///
6561/// folds into:
6562///
6563/// ```mlir
6564/// %2 = tensor.unpack %0 ... tensor<1x1x8x1xi32> -> tensor<7x?xi32>
6565/// ```
// NOTE(review): several source lines appear to be missing from this listing:
// the match condition guarding the second `return failure()`, the initializer
// of `newOperands`, the declaration of `newMixedTileSizes`, and the
// declaration of `replacement`. Confirm against the original file before
// editing.
6566struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
6567 using OpRewritePattern<UnPackOp>::OpRewritePattern;
6568
6569 LogicalResult matchAndRewrite(UnPackOp op,
6570 PatternRewriter &rewriter) const override {
6571 // TODO: Support Memref UnPackOp. Temporarily return failure.
6572 if (!op.hasPureTensorSemantics())
6573 return failure();
6574
6576 return failure();
6577
6578 SmallVector<Type> newResultTypes(op->getResultTypes());
6579 SmallVector<Value> newOperands =
6581 Value sourceTensor = newOperands[0];
6582
6583 // Get the updated mixed-tile-sizes attribute.
6585 rewriter, sourceTensor.getType(), op.getMixedTiles());
6586
6587 // Clone op.
6588 // TODO: Strictly speaking, discardable attributes should be _discarded_ at
6589 // this point. However, in practice, we use them for things that we'd like
6590 // to preserve. Implement a better abstraction.
6591 UnPackOp newOp = UnPackOp::create(rewriter, op.getLoc(), sourceTensor,
6592 newOperands[1], op.getInnerDimsPos(),
6593 newMixedTileSizes, op.getOuterDimsPerm());
6594 newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
6595
6596 // Replace op.
6597 Value oldResult = op.getResult();
6598 Value newResult = newOp.getResult();
     // If the new op's result type differs, cast back to the old result type
     // so existing users remain valid.
6600 (newResult.getType() != oldResult.getType())
6601 ? tensor::CastOp::create(rewriter, op->getLoc(),
6602 oldResult.getType(), newResult)
6603 : newResult;
6604
6605 rewriter.replaceOp(op, {replacement});
6606
6607 return success();
6608 }
6609};
6610
6611//===----------------------------------------------------------------------===//
6612// BatchReduceMatmulOp
6613//===----------------------------------------------------------------------===//
/// Returns the iterator types of the op's four loops, in order: reduction,
/// parallel, parallel, reduction.
// NOTE(review): the line that opens the returned initializer appears to be
// missing from this listing.
6614SmallVector<utils::IteratorType> BatchReduceMatmulOp::getIteratorTypesArray() {
6616 utils::IteratorType::reduction, utils::IteratorType::parallel,
6617 utils::IteratorType::parallel, utils::IteratorType::reduction};
6618}
6619
6620SmallVector<AffineMap>
6621BatchReduceMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
6622 AffineExpr d0, d1, d2, d3;
6623 SmallVector<AffineMap> indexingMaps;
6624 bindDims(context, d0, d1, d2, d3);
6625 indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));
6626 indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));
6627 indexingMaps.push_back(AffineMap::get(4, 0, {d1, d2}, context));
6628 return indexingMaps;
6629}
6630
6631bool BatchReduceMatmulOp::isDefaultIndexingMaps(Attribute attr) {
6632 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
6633 if (!maps)
6634 return false;
6635 if (maps.size() != 3)
6636 return false;
6637 auto positions = getAffineResultPositions(maps);
6638 if (failed(positions))
6639 return false;
6640 return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
6641 (*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
6642 (*positions)[2] == SmallVector<int64_t>{1, 2};
6643}
6644unsigned BatchReduceMatmulOp::getNumRegionArgs() { return 3; }
6645
6646std::string BatchReduceMatmulOp::getLibraryCallName() {
6647 return generateLibraryCallName(getOperation());
6648}
6649
6650/// Check if the op has broadcast and/or transpose semantic. Returns true if
6651/// the user defined indexing maps are not equal to default map.
6652bool BatchReduceMatmulOp::hasUserDefinedMaps() {
6653 SmallVector<AffineMap, 3> defaultMaps =
6654 getDefaultIndexingMaps(this->getContext());
6655 SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
6656 return defaultMaps != explicitMaps;
6657}
6658
6659/// Returns true if the given bcastMap map is a valid broadcast map. A valid
6660/// broadcast map must include K dimension.
6661/// TODO: Strict inclusion of K dimension in the broadcast map is not
6662/// necessary for both input matrices simultaneously. We can relax this
6663/// condition to have K dimension for one input matrix map and infer the K
6664/// dimension for other input matrix map from the one already having K
6665/// dimension.
6666bool BatchReduceMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap,
6667 bool isLHS) {
6668 assert(bcastMap.getNumResults() < 3 &&
6669 "Expected less than 3 result dim expr.");
6670 bool isValid = false;
6671 enum Indices { batchPos, mPos, nPos, kPos };
6672 if (bcastMap.getNumResults() == 1) {
6673 AffineExpr expr = bcastMap.getResult(0);
6674 isValid = expr.isFunctionOfDim(kPos);
6675 } else if (bcastMap.getNumResults() == 2) {
6676 AffineExpr expr0 = bcastMap.getResult(0);
6677 AffineExpr expr1 = bcastMap.getResult(1);
6678 isValid =
6679 isLHS ? ((expr0.isFunctionOfDim(batchPos) ||
6680 expr0.isFunctionOfDim(mPos)) &&
6681 expr1.isFunctionOfDim(kPos))
6682 : ((expr0.isFunctionOfDim(batchPos) &&
6683 expr1.isFunctionOfDim(kPos)) ||
6684 (expr0.isFunctionOfDim(kPos) && expr1.isFunctionOfDim(nPos)));
6685 }
6686 return isValid;
6687}
6688
6689void BatchReduceMatmulOp::regionBuilder(
6690 ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs,
6691 function_ref<InFlightDiagnostic()> emitError) {
6692 if (emitError && block.getNumArguments() != 3) {
6693 emitError() << "BatchReduceMatmulOp regionBuilder expects 3 args, got "
6694 << block.getNumArguments();
6695 return;
6696 }
6697 assert(block.getNumArguments() == 3 &&
6698 "BatchReduceMatmulOp regionBuilder expects 3 args");
6699 RegionBuilderHelper helper(b, block);
6700 SmallVector<Value> yields;
6701
6702 auto toType = block.getArgument(2).getType();
6703 Value castValA =
6704 helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(0));
6705 Value castValB =
6706 helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(1));
6707 Value mulVal =
6708 helper.buildBinaryFn(BinaryFn::mul, castValA, castValB, emitError);
6709 if (!castValA || !castValB || !mulVal)
6710 return;
6711 Value addVal =
6712 helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
6713 if (!addVal)
6714 return;
6715 yields.push_back(addVal);
6716 helper.yieldOutputs(yields);
6717}
6718
6719ParseResult BatchReduceMatmulOp::parse(OpAsmParser &parser,
6720 OperationState &result) {
6721 SmallVector<Attribute, 3> indexingMapsAttr;
6722 Attribute mapAttr;
6723 if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
6724 if (parser.parseEqual())
6725 return failure();
6726 if (parser.parseLSquare())
6727 return failure();
6728
6729 do {
6730 if (parser.parseAttribute(mapAttr))
6731 return failure();
6732 if (!isa<AffineMapAttr>(mapAttr)) {
6733 return parser.emitError(parser.getCurrentLocation(),
6734 "expected affine map attribute");
6735 }
6736 indexingMapsAttr.push_back(mapAttr);
6737
6738 if (parser.parseOptionalComma())
6739 break;
6740 } while (true);
6741
6742 if (parser.parseRSquare())
6743 return failure();
6744 }
6745 // Initialize indexingMaps, if not supplied explicitly.
6746 if (indexingMapsAttr.empty()) {
6747 indexingMapsAttr = llvm::map_to_vector(
6748 BatchReduceMatmulOp::getDefaultIndexingMaps(parser.getContext()),
6749 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
6750 }
6751 result.addAttribute("indexing_maps",
6752 parser.getBuilder().getArrayAttr(indexingMapsAttr));
6753 return ::parseNamedStructuredOp(parser, result,
6754 BatchReduceMatmulOp::getNumRegionArgs(),
6755 BatchReduceMatmulOp::getRegionBuilder());
6756}
6757
6758void BatchReduceMatmulOp::print(OpAsmPrinter &p) {
6759 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
6760 BatchReduceMatmulOp::getDefaultIndexingMaps(getContext()),
6761 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
6762
6763 if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
6764 p << " indexing_maps = [";
6765 llvm::interleaveComma(getIndexingMaps(), p,
6766 [&](Attribute attr) { p.printAttribute(attr); });
6767 p << "]";
6768 }
6769
6770 SmallVector<StringRef, 3> elidedAttrs = {
6771 "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
6772 ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
6773 elidedAttrs);
6774}
6775
6776/// Verify the user defined indexing maps.
///
/// Ops that use the default indexing maps succeed immediately. Otherwise a
/// per-operand check runs for operand indices 0, 1 and 2, and verification
/// fails as soon as one of them fails.
// NOTE(review): the condition inside the loop (the per-operand check) appears
// to be missing from this listing.
6777LogicalResult BatchReduceMatmulOp::verify() {
6778 // Verification of pure batch_reduce_matmul is handled by
6779 // verifyStructuredOpInterface().
6780 if (!hasUserDefinedMaps())
6781 return success();
6782
6783 for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
6785 return failure();
6786 }
6787 return success();
6788}
6789LogicalResult BatchReduceMatmulOp::fold(FoldAdaptor,
6790 SmallVectorImpl<OpFoldResult> &) {
6791 return memref::foldMemRefCast(*this);
6792}
6793void BatchReduceMatmulOp::getEffects(
6794 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
6795 &effects) {
6796 if (hasPureTensorSemantics())
6797 return;
6798 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
6799}
6800
6801Speculation::Speculatability BatchReduceMatmulOp::getSpeculatability() {
6802 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
6803}
6804
6805} // namespace linalg
6806} // namespace mlir
6807
6808//===----------------------------------------------------------------------===//
6809// LinalgDialect
6810//===----------------------------------------------------------------------===//
6811
6812void LinalgDialect::getCanonicalizationPatterns(
6813 RewritePatternSet &results) const {
6814 results.add<EraseDeadLinalgOp, FoldTensorCastConsumerOp, FoldTensorCastPackOp,
6815 FoldTensorCastUnPackOp, InferStaticShapeOfOperands>(getContext());
6816}
6817
6818Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
6819 Attribute value, Type type,
6820 Location loc) {
6821 return arith::ConstantOp::materialize(builder, value, type, loc);
6822}
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static Type getElementType(Type type)
Determine the element type of type.
static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp, unsigned opIndex)
Verifies the broadcast and transpose semantic specified by the explicit indexing map for the MatmulO...
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, function_ref< InFlightDiagnostic()> emitError, RegionBuilderFn regionBuilder)
Fills the region of a structured operation using the provided regionBuilder.
static void buildIdentityRegion(OpBuilder &builder, Location loc, Region &region, ValueRange inputs, ValueRange outputs)
static void buildBatchMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > defaultIndexingMaps)
static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator, Value denominator, Value output, int64_t dim)
Produce a linalg generic that computes the final step of the softmax decomposition.
static bool areResultExprsSubsetOf(AffineMap subMap, AffineMap fullMap)
static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t)
static bool canUseShortForm(Block *body, bool initFirst=false, bool mapInit=true)
static bool isBroadcasted(AffineMap explictMap, AffineMap defaultMap)
Check if the user defined map is valid broadcast map.
static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs, ValueRange outputs)
llvm::function_ref< void( ImplicitLocOpBuilder &, Block &, ArrayRef< NamedAttribute >, function_ref< InFlightDiagnostic()>)> RegionBuilderFn
static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser, NamedAttrList &attributes, StringRef attributeName)
static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName, ArrayRef< int64_t > attributeValue)
static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input, Value max, Value output, int64_t dim)
Produce a linalg generic that computes the second step of the softmax decomposition: res = exp(input ...
static void printShortForm(OpAsmPrinter &p, Operation *payloadOp)
static LogicalResult verifyOutputMap(OpTy batchVariantMatmulOp, AffineMap opIndexingMap)
This function checks if the given AffineMap for the output of a BatchMatmulOp/BatchReduceMatmulOp has...
static std::optional< TypedAttr > getScalarConstantAttrFromDenseSplat(Value input)
Definition LinalgOps.cpp:95
static void buildStructuredOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder)
Creates a structured operation given inputs, outputs, and attributes.
static ParseResult parseDstStyleOp(OpAsmParser &parser, OperationState &result, function_ref< ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn=nullptr)
static LogicalResult verifyInputMaps(OpTy batchVariantMatmulOp, AffineMap opIndexingMap, AffineMap defaultIndexingMap, bool isLHS)
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, int64_t dim)
static Speculation::Speculatability getGenericSpeculatabilityImpl(LinalgOp linalgOp)
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
static ParseResult parseNamedStructuredOp(OpAsmParser &parser, OperationState &result, unsigned numRegionArgs, RegionBuilderFn regionBuilder)
static void getGenericEffectsImpl(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects, LinalgOp linalgOp)
static void buildGenericRegion(OpBuilder &builder, Location loc, Region &region, ValueRange inputs, ValueRange outputs, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuild)
static ParseResult parseNamedStructuredOpResults(OpAsmParser &parser, SmallVectorImpl< Type > &resultTypes)
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v, int64_t dim)
Return a memref.dim or tensor.dim for the shape of v at dim.
Definition LinalgOps.cpp:60
static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result, const OperationName &payloadOpName, const NamedAttrList &payloadOpAttrs, ArrayRef< Value > operands, bool initFirst=false, bool mapInit=true)
static std::tuple< SmallVector< utils::IteratorType >, SmallVector< AffineMap > > computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank, int64_t dim, bool allParallel=false)
static void buildBatchReduceMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > indexingMaps)
static void printNamedStructuredOpResults(OpAsmPrinter &p, TypeRange resultTypes)
static void buildMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > defaultIndexingMaps)
static LogicalResult verifyExtendedBatchVariantMatmulSemantic(OpTy batchVariantMatmulOp, unsigned opIndex)
Verifies the broadcast and transpose semantic specified by the explicit indexing map for the BatchMat...
static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op, ValueRange inputs, ValueRange outputs, ArrayRef< StringRef > elidedAttrs={})
static ParseResult parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, SmallVectorImpl< Type > &inputTypes, SmallVectorImpl< Type > &outputTypes, bool addOperandSegmentSizes=true)
Common parsing used for both named structured ops created by ods-gen and by manually defined C++ ops.
static ParseResult parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region, unsigned numRegionArgs, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder, SMLoc loc)
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
b getContext())
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static LogicalResult getResultTilePosition(RewriterBase &rewriter, ReductionTilingStrategy reductionStrategy, int64_t index, Value tiledResult, TilingInterface op, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, ValueRange ivs, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > givenTileSizes, const SetVector< unsigned > &reductionDims, SmallVector< OpFoldResult > &resultOffset, SmallVector< OpFoldResult > &resultSize)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
Base type for affine expression.
Definition AffineExpr.h:68
bool isFunctionOfDim(unsigned position) const
Return true if the affine expression involves AffineDimExpr position.
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition AffineMap.h:46
AffineMap dropResults(ArrayRef< int64_t > positions) const
Definition AffineMap.h:299
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
@ Paren
Parens surrounding zero or more operands.
@ Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseOptionalArrow()=0
Parse a '->' token if present.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
virtual void decreaseIndent()
Decrease indentation.
virtual void increaseIndent()
Increase indentation.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
virtual void printAttribute(Attribute attr)
virtual void printNewline()
Print a newline and indent the printer to the start of the current operation/attribute/type.
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
OpListType & getOperations()
Definition Block.h:147
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition Block.cpp:158
BlockArgListType getArguments()
Definition Block.h:97
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:167
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:171
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition Builders.cpp:392
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:116
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:267
AffineExpr getAffineDimExpr(unsigned position)
Definition Builders.cpp:369
Location getUnknownLoc()
Definition Builders.cpp:25
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:271
MLIRContext * getContext() const
Definition Builders.h:56
IndexType getIndexType()
Definition Builders.cpp:55
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition Builders.cpp:323
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
IRValueT get() const
Return the current value being used by this operand.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Definition Builders.h:632
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
ArrayRef< NamedAttribute > getAttrs() const
Return all of the attributes on this operation.
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
Attribute set(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Definition Attributes.h:179
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual FailureOr< OperationName > parseCustomOperationName()=0
Parse the name of an operation, in the custom form.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:435
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition Builders.cpp:462
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:414
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:254
unsigned getOperandNumber() const
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
unsigned getResultNumber() const
Returns the number of this result.
Definition Value.h:466
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition Operation.h:559
result_iterator result_begin()
Definition Operation.h:438
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:537
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:432
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:240
unsigned getNumOperands()
Definition Operation.h:371
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:115
operand_type_range getOperandTypes()
Definition Operation.h:422
result_iterator result_end()
Definition Operation.h:439
result_type_range getResultTypes()
Definition Operation.h:453
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:403
result_range getResults()
Definition Operation.h:440
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:429
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & emplaceBlock()
Definition Region.h:46
iterator end()
Definition Region.h:56
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
bool isSignlessIntOrIndexOrFloat() const
Return true if this is a signless integer, index, or float type.
Definition Types.cpp:106
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition Value.cpp:46
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition Value.h:197
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:384
static Attribute parse(AsmParser &parser, Type type)
Specialization of linalg.batch_matmul op that has a transpose map on A.
Definition Linalg.h:251
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
static bool classof(Operation *op)
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose A matmul.
static BatchMatmulTransposeAOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Specialization of linalg.batch_matmul op that has a transpose map on B.
Definition Linalg.h:298
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose B matmul.
static bool classof(Operation *op)
static BatchMatmulTransposeBOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
Specialization of linalg.matmul op that has a transpose map on A.
Definition Linalg.h:157
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
static MatmulTransposeAOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose A matmul.
static bool classof(Operation *op)
Specialization of linalg.matmul op that has a transpose map on B.
Definition Linalg.h:204
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose B matmul.
static MatmulTransposeBOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
static bool classof(Operation *op)
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
static SmallVector< int64_t > asShapeWithAnyValueAsDynamic(ArrayRef< OpFoldResult > ofrs)
Converts OpFoldResults to int64_t shape entries, unconditionally mapping all Value's to kDynamic,...
static LogicalResult reifyResultShapesImpl(OpTy op, OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
static bool inferStaticShape(PackOp packOp, SmallVectorImpl< int64_t > &srcShape, SmallVectorImpl< int64_t > &destShape)
Returns true if the srcShape or destShape is different from the one in packOp and populates each with...
static SmallVector< int64_t > getStaticTilesImpl(OpTy op)
static void getPackUnPackEffectsImpl(OpTy op, SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects)
static bool isInvalidPackingPosSpecification(ArrayRef< int64_t > dimsPos, size_t rank)
Returns true if dimsPos is invalid.
static SmallVector< OpFoldResult > getMixedTilesImpl(OpTy op)
static DenseMap< int64_t, OpFoldResult > getDimAndTileMappingImpl(OpTy op)
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
static ArityGroupAndKind getArityGroupAndKind(ElementwiseKind kind)
static PackOrUnPackTransposeResult commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp, ArrayRef< int64_t > innerPermutation, ArrayRef< int64_t > outerPermutation)
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
static SmallVector< OpFoldResult > getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy, ArrayRef< OpFoldResult > mixedTiles)
static bool areTilesAndTiledDimsAllConstant(OpTy op)
Returns true if the tiles and the tiled dims are constant.
std::string generateLibraryCallName(Operation *op)
Returns the name mangled library call name to disambiguate between different overloads at the C level...
template SmallVector< int64_t > getPackedOuterShapeWithoutTransposition< UnPackOp >(UnPackOp)
static bool paddingIsNotNeeded(PackOp op)
Returns true if the pack op does not need a padding value.
static bool isLikePadUnPad(PackOrUnpackOp packOp, ShapedType packedTensorType)
AffineMap extractOrIdentityMap(std::optional< AffineMap > maybeMap, unsigned rank, MLIRContext *context)
Returns maybeMap.get() if maybeMap is set, otherwise returns the symbol-less identity map of rank.
SmallVector< AffineExpr, 4 > makeAffineDimExprs(unsigned num, unsigned &startIdx, MLIRContext *context)
Returns num AffineDimExpr dimensions at positions [startIdx, startIdx + num) and increments startIdx ...
static FailureOr< SmallVector< SmallVector< int64_t > > > getAffineResultPositions(ArrayAttr maps)
static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp)
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp)
template SmallVector< int64_t > getPackedOuterShapeWithoutTransposition< PackOp >(PackOp)
std::pair< int64_t, int64_t > getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr)
Converts the given WinogradConv2DFmr enumeration value to a pair of m and r parameters.
std::optional< WinogradConv2DFmr > getWinogradConv2DFmr(int64_t m, int64_t r)
Converts the given m and r parameters to a WinogradConv2DFmr enumeration value.
static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack)
static FailureOr< ArrayAttr > parseIndexingMapsAttr(OpAsmParser &parser)
SmallVector< int64_t > getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack)
Returns the outer shape in the packed domain before applying the transposition.
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition MemRefOps.cpp:47
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
bool hasFoldableTensorCastOperand(Operation *op)
Return true if any of the operands of op is a CastOp that can be folded into its consumer,...
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
SmallVector< Value > getUpdatedOperandsAfterCastOpFolding(DestinationStyleOpInterface op, SmallVector< Type > &newResTy)
Assuming that op contains at least one operand that is a foldable CastOp (i.e.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition TensorOps.cpp:69
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
Definition Utils.cpp:241
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
A functor used to set the name of the start of a result group of an operation.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hooks for custom directive in assemblyFormat.
bool areAllConstantIntValue(ArrayRef< OpFoldResult > ofrs, int64_t value)
Return true if all of ofrs are constant integers equal to value.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
Definition AffineExpr.h:311
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:122
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
bool isIdentityPermutation(ArrayRef< int64_t > permutation)
Returns true if permutation is an identity permutation.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. sizeof...(exprs)].
Definition AffineExpr.h:325
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:139
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:114
LogicalResult verifyRanksMatch(Operation *op, ShapedType lhs, ShapedType rhs, StringRef lhsName, StringRef rhsName)
Verify that two shaped types have matching ranks.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
Definition Utils.cpp:1330
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:120
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
SetVector< Operation * > getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions={}, const ForwardSliceOptions &forwardSliceOptions={})
Iteratively computes backward slices and forward slices until a fixed point is reached.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
SmallVector< int64_t > dropDims(ArrayRef< int64_t > inputPerm, ArrayRef< int64_t > dropPositions)
Returns a permutation vector that drops the input dims in dropPositions from inputPerm.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalableFlags, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hooks for custom directive in assemblyFormat.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to invert a permutation.
Rewrite a broadcast of a dense splat constant into a dense splat constant of the broadcast output sha...
LogicalResult matchAndRewrite(linalg::BroadcastOp broadcastOp, PatternRewriter &rewriter) const override
Fold back-to-back broadcasts together.
LogicalResult matchAndRewrite(linalg::BroadcastOp broadcastOp, PatternRewriter &rewriter) const override
Rewrite a transpose of a dense splat constant into a dense splat constant of the transposed output sh...
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
Fold transpose with transpose.
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
This pattern canonicalizes transpose by swapping the order of broadcast and transpose: transpose(broad...
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
This is the representation of an operand reference.
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.
Folds a tensor.cast op into a consuming PackOp op if the tensor.cast has source that is more static t...
LogicalResult matchAndRewrite(PackOp op, PatternRewriter &rewriter) const override
Folds a tensor.cast op into a consuming UnPackOp op if the tensor.cast has source that is more static...
LogicalResult matchAndRewrite(UnPackOp op, PatternRewriter &rewriter) const override