//===- TosaOps.cpp - MLIR Dialect for TOSA --------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// \file
// This file implements the TOSA Specification:
// https://www.mlplatform.org/tosa/tosa_spec.html
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/Matchers.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/ADT/TypeSwitch.h"

#include <numeric>
#include <type_traits>

using namespace mlir;
using namespace mlir::tosa;

#include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"

//===----------------------------------------------------------------------===//
// Tosa dialect interface includes.
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TosaAvailability.cpp.inc"
#include "mlir/Dialect/Tosa/IR/TosaEnums.cpp.inc"
#include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"
#include "mlir/Dialect/Tosa/IR/TosaOpAvailabilityImpl.inc"

namespace {
#include "mlir/Dialect/Tosa/IR/TosaDialectBytecode.cpp.inc"

//===----------------------------------------------------------------------===//
// Dialect Function Inliner Interface.
//===----------------------------------------------------------------------===//
struct TosaInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  //===--------------------------------------------------------------------===//
  // Analysis Hooks.
  //===--------------------------------------------------------------------===//

  /// All operations can be inlined by default.
  bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
                       IRMapping &map) const final {
    return true;
  }

  /// All regions with If and While parent operators can be inlined.
  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
                       IRMapping &map) const final {
    return (isa<tosa::IfOp>(dest->getParentOp()) ||
            isa<tosa::WhileOp>(dest->getParentOp()));
  }
};

/// This class implements the bytecode interface for the Tosa dialect.
struct TosaDialectBytecodeInterface : public BytecodeDialectInterface {
  TosaDialectBytecodeInterface(Dialect *dialect)
      : BytecodeDialectInterface(dialect) {}

  //===--------------------------------------------------------------------===//
  // Attributes

  Attribute readAttribute(DialectBytecodeReader &reader) const override {
    return ::readAttribute(getContext(), reader);
  }

  LogicalResult writeAttribute(Attribute attr,
                               DialectBytecodeWriter &writer) const override {
    return ::writeAttribute(attr, writer);
  }

  //===--------------------------------------------------------------------===//
  // Types

  Type readType(DialectBytecodeReader &reader) const override {
    return ::readType(getContext(), reader);
  }

  LogicalResult writeType(Type type,
                          DialectBytecodeWriter &writer) const override {
    return ::writeType(type, writer);
  }

  void writeVersion(DialectBytecodeWriter &writer) const final {
    // TODO: Populate.
  }

  std::unique_ptr<DialectVersion>
  readVersion(DialectBytecodeReader &reader) const final {
    // TODO: Populate.
    reader.emitError("Dialect does not support versioning");
    return nullptr;
  }

  LogicalResult upgradeFromVersion(Operation *topLevelOp,
                                   const DialectVersion &version) const final {
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// TOSA control flow support.
//===----------------------------------------------------------------------===//

/// Returns the while loop body.
SmallVector<Region *> tosa::WhileOp::getLoopRegions() {
  return {&getBodyGraph()};
}

//===----------------------------------------------------------------------===//
// TOSA variable operator support.
//===----------------------------------------------------------------------===//

SmallVector<int64_t> mlir::tosa::convertToMlirShape(ArrayRef<int64_t> shape) {
  return map_to_vector(shape, [](int64_t dim) {
    return dim == -1 ? ShapedType::kDynamic : dim;
  });
}
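
// For example, {2, -1, 4} maps to {2, ShapedType::kDynamic, 4}.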

// Returns the type of a variable op as a ranked tensor built from its
// var_shape and element type attributes.
RankedTensorType mlir::tosa::getVariableType(tosa::VariableOp variableOp) {
  Type elementType = variableOp.getType();
  DenseIntElementsAttr varShapeAttr = variableOp.getVarShape();
  auto shape = convertToMlirShape(to_vector(varShapeAttr.getValues<int64_t>()));
  return RankedTensorType::get(shape, elementType);
}
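
// For example, a variable with var_shape [1, 8] and element type f32 has the
// variable type tensor<1x8xf32>.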

//===----------------------------------------------------------------------===//
// Tosa dialect initialization.
//===----------------------------------------------------------------------===//

void TosaDialect::initialize() {
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
      >();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
      >();
  addInterfaces<TosaDialectBytecodeInterface, TosaInlinerInterface>();
  declarePromisedInterfaces<
      shard::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp,
      ArithmeticRightShiftOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, IntDivOp,
      LogicalAndOp, LogicalLeftShiftOp, LogicalRightShiftOp, LogicalOrOp,
      LogicalXorOp, MaximumOp, MinimumOp, MulOp, PowOp, SubOp, AbsOp,
      BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp, LogOp, LogicalNotOp,
      NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp, GreaterOp,
      GreaterEqualOp, MatMulOp>();
}

Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
                                            Type type, Location loc) {
  // TOSA dialect constants only support ElementsAttr, unlike the standard
  // dialect constant, which supports all attributes.
  if (llvm::isa<shapeType>(type) && llvm::isa<DenseIntElementsAttr>(value)) {
    return tosa::ConstShapeOp::create(builder, loc, type,
                                      llvm::cast<DenseIntElementsAttr>(value));
  }
  if (llvm::isa<ElementsAttr>(value))
    return tosa::ConstOp::create(builder, loc, type,
                                 llvm::cast<ElementsAttr>(value));
  return nullptr;
}

//===----------------------------------------------------------------------===//
// Parsers and printers
//===----------------------------------------------------------------------===//

namespace {

ParseResult getShapeAndElementType(OpAsmParser &parser, Type parsedType,
                                   DenseElementsAttr &varShapeAttr,
                                   TypeAttr &typeAttr) {
  if (auto shapedType = dyn_cast<ShapedType>(parsedType)) {
    if (!shapedType.hasRank())
      return parser.emitError(parser.getCurrentLocation())
             << "expected ranked type";

    auto elementType = shapedType.getElementType();
    typeAttr = TypeAttr::get(elementType);
    ArrayRef<int64_t> shape = shapedType.getShape();
    Builder builder(parser.getContext());
    varShapeAttr = builder.getIndexTensorAttr(convertFromMlirShape(shape));
    return success();
  }
  return parser.emitError(parser.getCurrentLocation())
         << "expected shaped type";
}

} // namespace

// Parses the optional initial value or type for a tosa variable.
// With an initial value:
//   tosa.variable @name = dense<0.0> : tensor<1x8xf32>
//
// Without an initial value:
//   tosa.variable @name : tensor<1x8xf32>
ParseResult mlir::tosa::parseVariableOpTypeOrInitialValue(
    OpAsmParser &parser, DenseElementsAttr &varShapeAttr, TypeAttr &typeAttr,
    Attribute &initialValueAttr) {
  if (succeeded(parser.parseOptionalEqual())) {
    if (failed(parser.parseAttribute(initialValueAttr))) {
      return parser.emitError(parser.getCurrentLocation())
             << "expected attribute";
    }
    if (auto typedAttr = dyn_cast<TypedAttr>(initialValueAttr)) {
      return getShapeAndElementType(parser, typedAttr.getType(), varShapeAttr,
                                    typeAttr);
    }
    return parser.emitError(parser.getCurrentLocation())
           << "expected Typed attr";
  }

  initialValueAttr = nullptr;
  Type parsedType;
  if (failed(parser.parseColonType(parsedType))) {
    return parser.emitError(parser.getCurrentLocation())
           << "expected type after colon";
  }
  return getShapeAndElementType(parser, parsedType, varShapeAttr, typeAttr);
}

void mlir::tosa::printVariableOpTypeOrInitialValue(
    OpAsmPrinter &p, Operation *op, DenseElementsAttr varShapeAttr,
    TypeAttr typeAttr, Attribute initialValueAttr) {
  bool needsSpace = false;
  if (!dyn_cast_or_null<TypedAttr>(initialValueAttr)) {
    auto shape =
        convertToMlirShape(to_vector(varShapeAttr.getValues<int64_t>()));
    Type elementType = typeAttr.getValue();
    RankedTensorType tensorType =
        RankedTensorType::get(ArrayRef<int64_t>(shape), elementType);
    auto tensorTypeAttr = TypeAttr::get(tensorType);
    p << ": ";
    p.printAttribute(tensorTypeAttr);
    needsSpace = true; // subsequent attr value needs a space separator
  }
  if (initialValueAttr) {
    if (needsSpace)
      p << ' ';
    p << "= ";
    p.printAttribute(initialValueAttr);
  }
}

namespace {

// Parse attributes with special handling for TOSA enum attributes.
template <typename EnumType>
ParseResult parseAttrEntryWithEnumHandling(OpAsmParser &parser,
                                           NamedAttrList &outAttrs) {
  llvm::StringRef name;
  if (parser.parseOptionalKeyword(&name) || parser.parseEqual())
    return failure();

  // Special handling: rounding_mode accepts a *bare* RoundingMode enum
  // keyword.
  llvm::StringRef kw;
  if constexpr (std::is_same_v<EnumType, tosa::RoundingMode>) {
    if (name == "rounding_mode" &&
        succeeded(parser.parseOptionalKeyword(&kw))) {
      auto sym = symbolizeRoundingMode(kw);
      if (!sym)
        return parser.emitError(parser.getCurrentLocation())
               << "invalid rounding_mode value: " << kw;
      auto attr = RoundingModeAttr::get(parser.getContext(), sym.value());
      outAttrs.push_back(NamedAttribute(name, attr));
      return success();
    }
  }
  // Special handling: mode accepts a *bare* ResizeMode enum keyword.
  if constexpr (std::is_same_v<EnumType, tosa::ResizeMode>) {
    if (name == "mode" && succeeded(parser.parseOptionalKeyword(&kw))) {
      auto sym = symbolizeResizeMode(kw);
      if (!sym)
        return parser.emitError(parser.getCurrentLocation())
               << "invalid resize mode value: " << kw;
      auto attr = ResizeModeAttr::get(parser.getContext(), sym.value());
      outAttrs.push_back(NamedAttribute(name, attr));
      return success();
    }
  }
  // Special handling: nan_mode accepts a *bare* NanPropagationMode enum
  // keyword.
  if constexpr (std::is_same_v<EnumType, tosa::NanPropagationMode>) {
    if (name == "nan_mode" && succeeded(parser.parseOptionalKeyword(&kw))) {
      auto sym = symbolizeNanPropagationMode(kw);
      if (!sym)
        return parser.emitError(parser.getCurrentLocation())
               << "invalid nan_mode value: " << kw;
      auto attr = NanPropagationModeAttr::get(parser.getContext(), sym.value());
      outAttrs.push_back(NamedAttribute(name, attr));
      return success();
    }
  }

  // Special handling: block_size accepts a *bare* BlockSize enum keyword.
  if constexpr (std::is_same_v<EnumType, tosa::BlockSize>) {
    if (name == "block_size" && succeeded(parser.parseOptionalKeyword(&kw))) {
      auto sym = symbolizeBlockSize(kw);
      if (!sym)
        return parser.emitError(parser.getCurrentLocation())
               << "invalid block_size value: " << kw;
      auto attr = BlockSizeAttr::get(parser.getContext(), sym.value());
      outAttrs.push_back(NamedAttribute(name, attr));
      return success();
    }
  }

  // Default path: parse any normal attribute literal, including a fully
  // qualified enum keyword.
  Attribute attr;
  return parser.parseAttribute(attr, name, outAttrs);
}
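
// With this handling, an op such as tosa.rescale can spell its enum attribute
// as a bare keyword, e.g. `rounding_mode = SINGLE_ROUND`, while the fully
// qualified attribute form is still accepted through the default path.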

template <typename EnumType>
ParseResult parseWithEnumHandling(OpAsmParser &parser, OperationState &result) {
  // Parse operands.
  SmallVector<OpAsmParser::UnresolvedOperand> operands;
  if (parser.parseCommaSeparatedList(
          [&]() { return parser.parseOperand(operands.emplace_back()); }))
    return failure();

  // Parse { attr-dict } with special handling for bare enum tokens.
  NamedAttrList attrs;
  if (succeeded(parser.parseOptionalLBrace()) &&
      failed(parser.parseOptionalRBrace())) {
    do {
      if (parseAttrEntryWithEnumHandling<EnumType>(parser, attrs))
        return failure();
    } while (succeeded(parser.parseOptionalComma()));
    if (parser.parseRBrace())
      return failure();
  }

  FunctionType fnTy;
  if (parser.parseColonType(fnTy))
    return failure();

  // Resolve operands and types.
  if (failed(parser.resolveOperands(operands, fnTy.getInputs(),
                                    parser.getCurrentLocation(),
                                    result.operands)))
    return failure();

  result.addTypes(fnTy.getResults());
  result.addAttributes(attrs);

  return success();
}
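
// The accepted textual form thus looks like (illustrative):
//   tosa.clamp %arg0 {min_val = ..., max_val = ..., nan_mode = IGNORE}
//       : (tensor<4xf32>) -> tensor<4xf32>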

void printNamedAttr(OpAsmPrinter &printer, const NamedAttribute namedAttr) {
  printer << namedAttr.getName().strref() << " = ";
  auto attr = namedAttr.getValue();
  if (auto roundingModeAttr = dyn_cast<tosa::RoundingModeAttr>(attr)) {
    printer << roundingModeAttr.getValue();
  } else if (auto resizeModeAttr = dyn_cast<tosa::ResizeModeAttr>(attr)) {
    printer << resizeModeAttr.getValue();
  } else if (auto nanPropagationModeAttr =
                 dyn_cast<tosa::NanPropagationModeAttr>(attr)) {
    printer << nanPropagationModeAttr.getValue();
  } else if (auto blockSizeAttr = dyn_cast<tosa::BlockSizeAttr>(attr)) {
    printer << blockSizeAttr.getValue();
  } else {
    printer.printAttribute(attr);
  }
}

// Print with special handling for the default-valued NanPropagationMode
// attribute.
void printWithNanPropagationHandling(OpAsmPrinter &printer, Operation *op) {
  printer << " ";
  printer.printOperands(op->getOperands());

  NamedAttrList toPrint(op->getAttrs());
  // Elide the nan_mode attribute when it carries the default value.
  const auto kDefaultNanValue = NanPropagationMode::PROPAGATE;
  for (auto attr : op->getAttrs()) {
    if (auto nanAttr = dyn_cast<NanPropagationModeAttr>(attr.getValue())) {
      if (nanAttr.getValue() == kDefaultNanValue) {
        // Elide from toPrint.
        toPrint.erase(attr.getName());
        break;
      }
    }
  }

  if (!toPrint.empty()) {
    printer << " {";
    llvm::interleaveComma(toPrint, printer,
                          [&](const NamedAttribute namedAttr) {
                            printNamedAttr(printer, namedAttr);
                          });
    printer << "}";
  }

  printer << " : ";
  printer.printFunctionalType(op);
}
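
// As a result, a default `nan_mode = PROPAGATE` never appears in the printed
// form, while a non-default `nan_mode = IGNORE` is kept.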

// Print with special handling for enums: RoundingMode, ResizeMode, BlockSize.
void printWithEnumHandling(OpAsmPrinter &printer, Operation *op) {
  printer << " ";
  printer.printOperands(op->getOperands());

  if (!op->getAttrs().empty()) {
    printer << " {";
    llvm::interleaveComma(op->getAttrs(), printer,
                          [&](const NamedAttribute namedAttr) {
                            printNamedAttr(printer, namedAttr);
                          });
    printer << "}";
  }

  printer << " : ";
  printer.printFunctionalType(op);
}

} // namespace

ParseResult RescaleOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseWithEnumHandling<tosa::RoundingMode>(parser, result);
}

void RescaleOp::print(OpAsmPrinter &printer) {
  printWithEnumHandling(printer, *this);
}

ParseResult ApplyScaleOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseWithEnumHandling<tosa::RoundingMode>(parser, result);
}

void ApplyScaleOp::print(OpAsmPrinter &printer) {
  printWithEnumHandling(printer, *this);
}

ParseResult ResizeOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseWithEnumHandling<tosa::ResizeMode>(parser, result);
}

void ResizeOp::print(OpAsmPrinter &printer) {
  printWithEnumHandling(printer, *this);
}

ParseResult ArgMaxOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
}

void ArgMaxOp::print(OpAsmPrinter &printer) {
  printWithNanPropagationHandling(printer, *this);
}

ParseResult MaxPool2dOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
}

void MaxPool2dOp::print(OpAsmPrinter &printer) {
  printWithNanPropagationHandling(printer, *this);
}

ParseResult MaxPool2dAdaptiveOp::parse(OpAsmParser &parser,
                                       OperationState &result) {
  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
}

void MaxPool2dAdaptiveOp::print(OpAsmPrinter &printer) {
  printWithNanPropagationHandling(printer, *this);
}

ParseResult ClampOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
}

void ClampOp::print(OpAsmPrinter &printer) {
  printWithNanPropagationHandling(printer, *this);
}

ParseResult MaximumOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
}

void MaximumOp::print(OpAsmPrinter &printer) {
  printWithNanPropagationHandling(printer, *this);
}

ParseResult MinimumOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
}

void MinimumOp::print(OpAsmPrinter &printer) {
  printWithNanPropagationHandling(printer, *this);
}

ParseResult ReduceMaxOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
}

void ReduceMaxOp::print(OpAsmPrinter &printer) {
  printWithNanPropagationHandling(printer, *this);
}

ParseResult ReduceMinOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
}

void ReduceMinOp::print(OpAsmPrinter &printer) {
  printWithNanPropagationHandling(printer, *this);
}

ParseResult MatmulTBlockScaledOp::parse(OpAsmParser &parser,
                                        OperationState &result) {
  return parseWithEnumHandling<tosa::BlockSize>(parser, result);
}

void MatmulTBlockScaledOp::print(OpAsmPrinter &printer) {
  printWithEnumHandling(printer, *this);
}

ParseResult CastFromBlockScaledOp::parse(OpAsmParser &parser,
                                         OperationState &result) {
  return parseWithEnumHandling<tosa::BlockSize>(parser, result);
}

void CastFromBlockScaledOp::print(OpAsmPrinter &printer) {
  printWithEnumHandling(printer, *this);
}

ParseResult CastToBlockScaledOp::parse(OpAsmParser &parser,
                                       OperationState &result) {
  return parseWithEnumHandling<tosa::BlockSize>(parser, result);
}

void CastToBlockScaledOp::print(OpAsmPrinter &printer) {
  printWithEnumHandling(printer, *this);
}

ParseResult Conv2DBlockScaledOp::parse(OpAsmParser &parser,
                                       OperationState &result) {
  return parseWithEnumHandling<tosa::BlockSize>(parser, result);
}

void Conv2DBlockScaledOp::print(OpAsmPrinter &printer) {
  printWithEnumHandling(printer, *this);
}

//===----------------------------------------------------------------------===//
// Tosa utilities.
//===----------------------------------------------------------------------===//

static std::optional<int64_t> idivCheck(const int64_t lhs, const int64_t rhs) {
  if (lhs % rhs != 0)
    return std::nullopt;
  return lhs / rhs;
}
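
// For example, idivCheck(12, 4) yields 3, while idivCheck(13, 4) yields
// std::nullopt because 13 is not evenly divisible by 4.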

Type mlir::tosa::getStorageElementTypeOrSelf(Type type) {
  auto srcType = getElementTypeOrSelf(type);
  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcType))
    srcType = getStorageElementTypeFromQuantized(quantType);
  return srcType;
}

Type mlir::tosa::getStorageElementTypeOrSelf(Value value) {
  return getStorageElementTypeOrSelf(value.getType());
}

static LogicalResult verifyRescaleValueAndZpTypes(Operation *op, Value val,
                                                  Value valZp, StringRef name) {
  Type eType = getStorageElementTypeOrSelf(val.getType());
  Type eZpType = getStorageElementTypeOrSelf(valZp.getType());

  bool bothInts =
      mlir::isa<IntegerType>(eType) && mlir::isa<IntegerType>(eZpType);
  bool sameBitWidth =
      (eType.getIntOrFloatBitWidth() == eZpType.getIntOrFloatBitWidth());

  if (!bothInts || !sameBitWidth) {
    return op->emitOpError()
           << "expected " << name << " and " << name
           << "_zp to both be integer of the same bitwidth, but got " << eType
           << " vs. " << eZpType;
  }
  return success();
}

// Create a rank-1 pad-const tensor holding the value `val` in the required
// data type.
Value mlir::tosa::createPadConstTensor(OpBuilder &builder, Location loc,
                                       Value src, int32_t val) {
  const auto srcType = getElementTypeOrSelf(src);
  const auto srcElemType = getStorageElementTypeOrSelf(src);
  const auto padConstType = mlir::RankedTensorType::get({1}, srcType);
  const auto padConstEType = mlir::RankedTensorType::get({1}, srcElemType);
  const auto padConstAttr{
      llvm::isa<FloatType>(srcElemType)
          ? DenseElementsAttr::get(padConstEType,
                                   builder.getFloatAttr(srcElemType, val))
          : DenseElementsAttr::get(padConstEType,
                                   builder.getIntegerAttr(srcElemType, val))};
  return tosa::ConstOp::create(builder, loc, padConstType, padConstAttr);
}
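
// For example, for an f32 source tensor and val = 0 this creates a rank-1
// tosa.const of type tensor<1xf32> holding dense<0.0>.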

  if (dyn_cast<tosa::mxint8Type>(type))
    return 8;
  return type.getIntOrFloatBitWidth();
}

// Update dim size if the current dim is dynamic; otherwise, raise an error if
// the sizes do not match.
LogicalResult tryUpdateDimOrFailure(Operation *op, int64_t &currDim,
                                    const int64_t newDim,
                                    const StringRef operandName,
                                    const StringRef dimName) {
  if (ShapedType::isDynamic(currDim)) {
    currDim = newDim;
    return success();
  } else if (ShapedType::isStatic(newDim) && currDim != newDim) {
    return op->emitOpError("expected ")
           << dimName << " of " << operandName << " to match size " << currDim
           << ", got " << newDim;
  }
  return success();
}

static void printShapeToDiagnostic(InFlightDiagnostic &diag,
                                   ArrayRef<int64_t> shape) {
  auto printDim = [&](int64_t dim) {
    if (ShapedType::isDynamic(dim))
      diag << "?";
    else
      diag << dim;
  };

  llvm::interleaveComma(shape, diag, printDim);
}

static LogicalResult
                      ArrayRef<int64_t> expectedShape,
                      StringRef outputName = "output") {
  assert(outputType.hasRank() && "expected output type to be ranked");

  if (succeeded(verifyCompatibleShape(outputType.getShape(), expectedShape)))
    return success();

  InFlightDiagnostic diag = op->emitOpError("expected ");
  diag << outputName << " shape ";
  printShapeToDiagnostic(diag, outputType.getShape());
  diag << " to be compatible with inferred shape ";
  printShapeToDiagnostic(diag, expectedShape);
  return diag;
}

static LogicalResult verifyConvOutputSize(
    Operation *op, const int64_t inputSize, const int64_t kernelSize,
    const int64_t outputSize, const int64_t padBefore, const int64_t padAfter,
    const int64_t stride, const int64_t dilation, const llvm::StringRef dimName,
    const llvm::StringRef dimAxis, const llvm::StringRef padBeforeName,
    const llvm::StringRef padAfterName) {
  if (inputSize == ShapedType::kDynamic || kernelSize == ShapedType::kDynamic)
    return success();

  // ERROR_IF: O != idiv_check(I - 1 + pa + pb - (K - 1) * d, s) + 1

  const std::optional<int64_t> calculatedOutSizeMinusOne = idivCheck(
      inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation,
      stride);
  if (!calculatedOutSizeMinusOne.has_value())
    return op->emitOpError("expected input_")
           << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
           << padAfterName << " - (kernel_" << dimName << " - 1) * dilation_"
           << dimAxis << " to be wholly divisible by stride_" << dimAxis
           << ", got (" << inputSize << " - 1 + " << padBefore << " + "
           << padAfter << " - (" << kernelSize << " - 1) * " << dilation
           << ") / " << stride;

  const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
  if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
    return op->emitOpError("calculated output ")
           << dimName << " did not match expected: "
           << "calculated=" << calculatedOutSize << ", expected=" << outputSize;

  return success();
}
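
// Worked example: for input height I = 14, kernel K = 3, pad = {1, 1},
// dilation d = 1, and stride s = 1, the expected output height is
// idiv_check(14 - 1 + 1 + 1 - 2, 1) + 1 = 13 + 1 = 14.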

//===----------------------------------------------------------------------===//
// TOSA Operator Verifiers.
//===----------------------------------------------------------------------===//

template <typename T>
static LogicalResult verifyConvOp(T op) {
  const auto inputType = llvm::dyn_cast<TensorType>(op.getInput().getType());
  const auto weightType = llvm::dyn_cast<TensorType>(op.getWeight().getType());

  auto inputEType = inputType.getElementType();
  auto weightEType = weightType.getElementType();
  auto biasEType =
      llvm::cast<ShapedType>(op.getBias().getType()).getElementType();
  auto resultEType =
      llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
  bool biasIsFloat = llvm::isa<FloatType>(biasEType);
  bool resultIsFloat = llvm::isa<FloatType>(resultEType);

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
    inputEType = getStorageElementTypeFromQuantized(quantType);

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(weightEType))
    weightEType = getStorageElementTypeFromQuantized(quantType);

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
    biasEType = getStorageElementTypeFromQuantized(quantType);

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
    resultEType = getStorageElementTypeFromQuantized(quantType);

  if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
    // For now, only enforce bias element type == result element type for
    // float types.
    op.emitOpError(
          "expect both bias and result to have same element type, got ")
        << biasEType << " and " << resultEType;
    return failure();
  }

  if (isa<Float8E5M2Type>(inputEType) || isa<Float8E4M3FNType>(inputEType) ||
      isa<Float8E5M2Type>(weightEType) || isa<Float8E4M3FNType>(weightEType)) {
    if (inputEType != weightEType) {
      op.emitOpError(
            "expect both input and weight to have same element type, got ")
          << inputEType << " and " << weightEType;
      return failure();
    }
  }

  bool inputIsFloat = llvm::isa<FloatType>(inputEType);
  bool weightIsFloat = llvm::isa<FloatType>(weightEType);

  // Either both must be float or both non-float.
  if (inputIsFloat != weightIsFloat) {
    op.emitOpError(
          "expect both input and weight to be float or not together, got ")
        << inputEType << " and " << weightEType;
    return failure();
  }

  auto inputZpEType = getStorageElementTypeOrSelf(op.getInputZp().getType());
  if (inputEType != inputZpEType) {
    return op.emitOpError("expect both input and its zero point are the same "
                          "element type, got ")
           << inputEType << " and " << inputZpEType;
  }

  auto weightZpEType = getStorageElementTypeOrSelf(op.getWeightZp().getType());
  if (weightEType != weightZpEType) {
    return op.emitOpError("expect both weight and its zero point are the same "
                          "element type, got ")
           << weightEType << " and " << weightZpEType;
  }

  FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
  if (succeeded(maybeIZp) && op.verifyInputZeroPoint(*maybeIZp).failed())
    return failure();

  FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
  if (succeeded(maybeWZp) && op.verifyWeightZeroPoint(*maybeWZp).failed())
    return failure();

  return success();
}

LogicalResult tosa::ConstOp::verify() {
  auto attrType = llvm::dyn_cast<TensorType>(getValuesAttr().getType());
  auto outputType = llvm::dyn_cast<TensorType>(getOutput().getType());

  if (!attrType || !outputType) {
    emitOpError("expected tensors for attr/result type");
    return failure();
  }

  if (auto result = llvm::dyn_cast<mlir::quant::QuantizedType>(
          outputType.getElementType())) {
    if (getStorageElementTypeFromQuantized(result) == attrType.getElementType())
      return success();
  }

  if (attrType.getElementType() != outputType.getElementType()) {
    emitOpError("expected same attr/result element types");
    return failure();
  }

  return success();
}

template <typename T>
static LogicalResult verifyConvOpModes(T op) {
  auto inputEType =
      llvm::cast<ShapedType>(op.getInput().getType()).getElementType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
    inputEType = getStorageElementTypeFromQuantized(quantType);

  auto accType = op.getAccType();
  if (inputEType.isInteger(8) && !accType.isInteger(32))
    return op.emitOpError("accumulator type for i8 tensor is not i32, got ")
           << accType;

  if (inputEType.isInteger(16) && !accType.isInteger(48))
    return op.emitOpError("accumulator type for i16 tensor is not i48, got ")
           << accType;

  if (isa<Float8E5M2Type, Float8E4M3Type>(inputEType) &&
      !(accType.isF16() || accType.isF32()))
    return op.emitOpError("accumulator type for f8 tensor is not f16/f32, got ")
           << accType;

  if (inputEType.isF16() && !(accType.isF16() || accType.isF32()))
    return op.emitOpError(
               "accumulator type for f16 tensor is not f16/f32, got ")
           << accType;

  if (inputEType.isBF16() && !accType.isF32())
    return op.emitOpError("accumulator type for bf16 tensor is not f32, got ")
           << accType;

  if (inputEType.isF32() && !accType.isF32())
    return op.emitOpError("accumulator type for f32 tensor is not f32, got ")
           << accType;

  auto resultEType =
      llvm::cast<ShapedType>(op.getResult().getType()).getElementType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
    resultEType = getStorageElementTypeFromQuantized(quantType);

  return success();
}

//===----------------------------------------------------------------------===//
// ERROR_IF functions.
// ERROR_IF is a predicate that must set an error if the condition holds.
//===----------------------------------------------------------------------===//

template <typename T>
static LogicalResult verifyConvOpErrorIf(T op) {
  llvm::ArrayRef<int64_t> padding = op.getPad();
  if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
    return op.emitOpError("expect all padding values to be >= 0, got ")
           << padding;

  llvm::ArrayRef<int64_t> strides = op.getStride();
  if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
    return op.emitOpError("expect all stride values to be >= 1, got ")
           << strides;

  llvm::ArrayRef<int64_t> dilations = op.getDilation();
  if (llvm::any_of(dilations, [](int64_t d) { return d < 1; }))
    return op.emitOpError("expect all dilation values to be >= 1, got ")
           << dilations;

  const RankedTensorType outputType =
      llvm::dyn_cast<RankedTensorType>(op.getOutput().getType());
  if (!outputType)
    // Skip following checks if output is not ranked.
    return success();

  const RankedTensorType inputType =
      llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
  const RankedTensorType weightType =
      llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());

  if (inputType && weightType) {
    // input = [_,IH,IW,_], weight = [_,KH,KW,_], output = [_,OH,OW,_]
    if constexpr (std::is_same<T, tosa::Conv2DOp>::value) {
      if (failed(verifyConvOutputSize(
              op, inputType.getDimSize(1), weightType.getDimSize(1),
              outputType.getDimSize(1), padding[0], padding[1], strides[0],
              dilations[0], "height", "y", "top", "bottom")))
        return failure();

      if (failed(verifyConvOutputSize(
              op, inputType.getDimSize(2), weightType.getDimSize(2),
              outputType.getDimSize(2), padding[2], padding[3], strides[1],
              dilations[1], "width", "x", "left", "right")))
        return failure();
    }

    // input = [_,IH,IW,_], weight = [KH,KW,_,_], output = [_,OH,OW,_]
    if constexpr (std::is_same<T, tosa::DepthwiseConv2DOp>::value) {
      if (failed(verifyConvOutputSize(
              op, inputType.getDimSize(1), weightType.getDimSize(0),
              outputType.getDimSize(1), padding[0], padding[1], strides[0],
              dilations[0], "height", "y", "top", "bottom")))
        return failure();

      if (failed(verifyConvOutputSize(
              op, inputType.getDimSize(2), weightType.getDimSize(1),
              outputType.getDimSize(2), padding[2], padding[3], strides[1],
              dilations[1], "width", "x", "left", "right")))
        return failure();
    }

    // input = [_,ID,IH,IW,_], weight = [_,KD,KH,KW,_], output = [_,OD,OH,OW,_]
    if constexpr (std::is_same<T, tosa::Conv3DOp>::value) {
      if (failed(verifyConvOutputSize(
              op, inputType.getDimSize(1), weightType.getDimSize(1),
              outputType.getDimSize(1), padding[0], padding[1], strides[0],
              dilations[0], "depth", "d", "front", "back")))
        return failure();

      if (failed(verifyConvOutputSize(
              op, inputType.getDimSize(2), weightType.getDimSize(2),
              outputType.getDimSize(2), padding[2], padding[3], strides[1],
              dilations[1], "height", "y", "top", "bottom")))
        return failure();

      if (failed(verifyConvOutputSize(
              op, inputType.getDimSize(3), weightType.getDimSize(3),
              outputType.getDimSize(3), padding[4], padding[5], strides[2],
              dilations[2], "width", "x", "left", "right")))
        return failure();
    }
  }

  const RankedTensorType biasType =
      llvm::dyn_cast<RankedTensorType>(op.getBias().getType());
  if (!biasType)
    // Skip following checks if bias is not ranked.
    return success();

  const int64_t biasChannels = biasType.getDimSize(0);
  const int64_t outputChannels =
      outputType.getDimSize(outputType.getRank() - 1);
  if (biasChannels == ShapedType::kDynamic ||
      outputChannels == ShapedType::kDynamic)
    // Skip following checks if biasChannels or outputChannels is dynamic.
    return success();

  if (biasChannels != outputChannels && biasChannels != 1)
    return op.emitOpError(
               "bias channels expected to be equal to output channels (")
           << outputChannels << ") or 1, got " << biasChannels;

  return success();
}

// Verify that the two given types have the same element type and compatible
// shapes.
static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, Type type1,
                                                StringRef name1, Type type2,
                                                StringRef name2) {
  auto shapeType1 = dyn_cast<ShapedType>(type1);
  auto shapeType2 = dyn_cast<ShapedType>(type2);
  if (!shapeType1 || !shapeType2)
    return failure();

  auto elemType1 = shapeType1.getElementType();
  auto elemType2 = shapeType2.getElementType();
  if (elemType1 != elemType2)
    return op->emitOpError()
           << "require same element type for " << name1 << " (" << elemType1
           << ") and " << name2 << " (" << elemType2 << ")";

  if (failed(verifyCompatibleShape(type1, type2)))
    return op->emitOpError()
           << "require same shapes for " << name1 << " (" << type1 << ") and "
           << name2 << " (" << type2 << ")";

  return success();
}

// Verify that the two given tensor lists have the same length and that their
// elements pairwise match in element type and shape.
static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, ValueRange list1,
                                                StringRef name1,
                                                ValueRange list2,
                                                StringRef name2) {
  if (list1.size() != list2.size())
    return op->emitOpError()
           << "require same number of values in " << name1 << " ("
           << list1.size() << ") and " << name2 << " (" << list2.size() << ")";

  for (auto [type1, type2] :
       llvm::zip_equal(list1.getTypes(), list2.getTypes())) {
    if (errorIfTypeOrShapeMismatch(op, type1, name1, type2, name2).failed())
      return failure();
  }

  return success();
}

static inline LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type) {
  ShapeAdaptor shapeAdaptor(type);
  if (!shapeAdaptor.hasRank() || !shapeAdaptor.hasStaticShape())
    return success();

  return shapeAdaptor.getNumElements() == 1 ? success() : failure();
}

template <typename T>
static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) {
  Operation *symTableOp =
      op->template getParentWithTrait<OpTrait::SymbolTable>();
  if (!symTableOp)
    // If the operation is not within the scope of a symbol table, we cannot
    // verify it against its declaration.
    return success();

  SymbolTable symTable(symTableOp);
  const auto varOp = symTable.lookup<tosa::VariableOp>(op.getName());

  // Verify prior declaration.
  if (!varOp)
    return op->emitOpError("'")
           << op.getName() << "' has not been declared by 'tosa.variable'";

  // Verify type and shape.
  auto variableType = getVariableType(varOp);
  if (errorIfTypeOrShapeMismatch(op, type, name, variableType,
                                 "the input tensor")
          .failed())
    return failure();
  return success();
}

// Verify that aType and bType have the same element type.
template <typename T>
static LogicalResult verifySameElementTypes(T op, Type aType, Type bType,
                                            StringRef aName = "input",
                                            StringRef bName = "output") {
  auto aTType = llvm::dyn_cast<TensorType>(aType);
  auto bTType = llvm::dyn_cast<TensorType>(bType);
  if (!aTType) {
    op.emitOpError("expect shaped tensor for ") << aName << ", got " << aType;
    return failure();
  }
  if (!bTType) {
    op.emitOpError("expect shaped tensor for ") << bName << ", got " << bType;
    return failure();
  }
  auto aElementType = aTType.getElementType();
  auto bElementType = bTType.getElementType();
  auto aQuantType =
      llvm::dyn_cast<mlir::quant::UniformQuantizedType>(aElementType);
  auto bQuantType =
      llvm::dyn_cast<mlir::quant::UniformQuantizedType>(bElementType);
  if ((aElementType.isIntOrIndexOrFloat() || aQuantType) &&
      (bElementType.isIntOrIndexOrFloat() || bQuantType) &&
      aElementType != bElementType) {
    // Only check if both element types are int/index/float/UniformQuantized;
    // e.g., it is not clear how to check quant::QuantizedType. This happens
    // in test_conv2d_q_grouped_convolution in tfl-to-tosa-pipeline.mlir.
    op.emitOpError("expect ")
        << aName << " and " << bName << " to have same element type, got "
        << aElementType << " and " << bElementType;
    return failure();
  }
  return success();
}

LogicalResult tosa::ArgMaxOp::verify() {
  const ShapedType resultType = llvm::cast<ShapedType>(getType());

  // Ensure the result element type is an integer type.
  if (const auto resultETy = resultType.getElementType();
      !resultETy.isIntOrIndex())
    return emitOpError("result tensor is not of integer type");

  const auto inputType = llvm::cast<ShapedType>(getInput().getType());
  if (!inputType.hasRank())
    return success();

  // Ensure the axis is within the tensor rank.
  const int64_t axis = getAxisAttr().getInt();
  if (((axis < 0) || axis >= inputType.getRank()))
    return emitOpError("specified axis is outside the rank of the tensor");

  if (!resultType.hasRank())
    return success();

  const ArrayRef<int64_t> inputShape = inputType.getShape();
  const ArrayRef<int64_t> outputShape = resultType.getShape();
  llvm::SmallVector<int64_t> expectedOutputShape(inputShape);
  expectedOutputShape.erase(expectedOutputShape.begin() + axis);
  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape)))
    return emitOpError("expected output shape '")
           << expectedOutputShape << "', got '" << outputShape << "'";

  return success();
}
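
// For example, tosa.argmax with axis = 1 over a tensor<2x3x4xf32> input
// expects the output shape {2, 4}: the reduced axis is erased from the input
// shape.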

static LogicalResult verifyPoolingOpImpl(Operation *op,
                                         ArrayRef<int64_t> kernel,
                                         ArrayRef<int64_t> strides,
                                         ArrayRef<int64_t> padding, Value input,
                                         Value output) {
  const bool hasKernel = kernel.size() > 0;
  const bool hasStrides = strides.size() > 0;
  const bool hasPad = padding.size() > 0;

  if (hasKernel && llvm::any_of(kernel, [](int64_t s) { return s < 1; }))
    return op->emitOpError("expect all kernel values to be >= 1, got ")
           << kernel;

  if (hasStrides && llvm::any_of(strides, [](int64_t s) { return s < 1; }))
    return op->emitOpError("expect all stride values to be >= 1, got ")
           << strides;

  if (hasPad && llvm::any_of(padding, [](int64_t p) { return p < 0; }))
    return op->emitOpError("expect all padding values to be >= 0, got ")
           << padding;

  if (hasKernel && hasPad) {
    // Padding must be less than the kernel size to avoid a divide-by-zero.
    const int64_t kernelX = kernel[1];
    const int64_t padLeft = padding[2];
    const int64_t padRight = padding[3];
    if (padRight >= kernelX || padLeft >= kernelX)
      return op->emitOpError("expected left/right padding to be less than the "
                             "width of the kernel, got pad_left=")
             << padLeft << ", pad_right=" << padRight
             << ", kernel_x=" << kernelX;

    const int64_t kernelY = kernel[0];
    const int64_t padTop = padding[0];
    const int64_t padBottom = padding[1];
    if (padTop >= kernelY || padBottom >= kernelY)
      return op->emitOpError("expected top/bottom padding to be less than the "
                             "height of the kernel, got pad_top=")
             << padTop << ", pad_bottom=" << padBottom
             << ", kernel_y=" << kernelY;
  }

  const auto inputType = llvm::dyn_cast<RankedTensorType>(input.getType());
  const auto outputType = llvm::dyn_cast<RankedTensorType>(output.getType());
  if (!inputType || !outputType)
    return success();

  if (hasKernel && hasStrides && hasPad) {
    const auto verifyOutputSize =
        [op](const int64_t inputSize, const int64_t outputSize,
             const int64_t kernelSize, const int64_t strideSize,
             const int64_t padBefore, const int64_t padAfter,
             const llvm::StringRef dimName, const llvm::StringRef dimAxis,
             const llvm::StringRef padBeforeName,
             const llvm::StringRef padAfterName) -> LogicalResult {
      if (ShapedType::isDynamic(inputSize))
        return success();

      const std::optional<int64_t> calculatedOutSizeMinusOne =
          idivCheck(inputSize + padBefore + padAfter - kernelSize, strideSize);
      if (!calculatedOutSizeMinusOne.has_value())
        return op->emitOpError("expected input_")
               << dimName << " + pad_" << padBeforeName << " + pad_"
               << padAfterName << " - kernel_" << dimAxis
               << " to be wholly divisible by stride_" << dimAxis << ", got ("
               << inputSize << " + " << padBefore << " + " << padAfter << " - "
               << kernelSize << ") / " << strideSize;

      const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
      if (ShapedType::isStatic(outputSize) && calculatedOutSize != outputSize)
        return op->emitOpError("calculated output ")
               << dimName << " did not match expected: " << "calculated="
               << calculatedOutSize << ", expected=" << outputSize;

      return success();
    };

    if (failed(verifyOutputSize(inputType.getDimSize(1),
                                outputType.getDimSize(1), kernel[0], strides[0],
                                padding[0], padding[1], "height", "y", "top",
                                "bottom")))
      return failure();

    if (failed(verifyOutputSize(
            inputType.getDimSize(2), outputType.getDimSize(2), kernel[1],
            strides[1], padding[2], padding[3], "width", "x", "left", "right")))
      return failure();
  }
  return success();
}
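
// Worked example: for input height I = 32, kernel K = 2, stride s = 2, and
// pad = {0, 0}, the expected output height is
// idivCheck(32 + 0 + 0 - 2, 2) + 1 = 15 + 1 = 16.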

template <typename T>
static LogicalResult verifyPoolingOp(T op) {
  return verifyPoolingOpImpl(op.getOperation(), op.getKernel(), op.getStride(),
                             op.getPad(), op.getInput(), op.getOutput());
}

template <typename T>
static LogicalResult verifyAvgPoolCommonTypeAndZpChecks(T op) {
  const Type inputETy = getStorageElementTypeOrSelf(op.getInput().getType());
  const Type resultETy = getStorageElementTypeOrSelf(op.getOutput().getType());
  const Type inputZpETy =
      getStorageElementTypeOrSelf(op.getInputZp().getType());
  const Type outputZpETy =
      getStorageElementTypeOrSelf(op.getOutputZp().getType());

  auto accType = op.getAccType();
  if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
    return op.emitOpError("accumulator type for integer tensor is not i32");

  if (inputETy.isF16() && !(accType.isF16() || accType.isF32()))
    return op.emitOpError("accumulator type for f16 tensor is not f16/f32");

  if (inputETy.isBF16() && !accType.isF32())
    return op.emitOpError("accumulator type for bf16 tensor is not f32");

  if (inputETy.isF32() && !accType.isF32())
    return op.emitOpError("accumulator type for f32 tensor is not f32");

  if (inputETy != inputZpETy)
    return op.emitOpError("expect both input and its zero point are the same "
                          "element type, got ")
           << inputETy << " and " << inputZpETy;

  if (resultETy != outputZpETy)
    return op.emitOpError("expect both output and its zero point are the same "
                          "element type, got ")
           << resultETy << " and " << outputZpETy;

  FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
  if (succeeded(maybeIZp) && op.verifyInputZeroPoint(*maybeIZp).failed())
    return failure();

  FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
  if (succeeded(maybeOZp) && op.verifyOutputZeroPoint(*maybeOZp).failed())
    return failure();

  return success();
}

namespace {
struct AdaptivePoolingConstShapeValues {
  llvm::SmallVector<int64_t> kernel;
  llvm::SmallVector<int64_t> stride;
  llvm::SmallVector<int64_t> pad;
};
} // namespace

template <typename T>
constexpr bool IsSupportedAdaptivePoolConstShapeVerifyOp =
    std::is_same_v<T, tosa::AvgPool2dAdaptiveOp> ||
    std::is_same_v<T, tosa::MaxPool2dAdaptiveOp>;

template <typename T,
          typename std::enable_if<IsSupportedAdaptivePoolConstShapeVerifyOp<T>,
                                  int>::type = 0>
static void extractAdaptivePoolingConstShapeOperands(
    T op, AdaptivePoolingConstShapeValues &values) {
  tosa::getConstShapeValues(op.getKernel().getDefiningOp(), values.kernel);
  tosa::getConstShapeValues(op.getStride().getDefiningOp(), values.stride);
  tosa::getConstShapeValues(op.getPad().getDefiningOp(), values.pad);
}

LogicalResult tosa::AvgPool2dOp::verify() {
  if (failed(verifyPoolingOp(*this)))
    return failure();
  if (failed(verifyAvgPoolCommonTypeAndZpChecks(*this)))
    return failure();
  return success();
}

LogicalResult tosa::AvgPool2dAdaptiveOp::verify() {
  AdaptivePoolingConstShapeValues values;
  extractAdaptivePoolingConstShapeOperands(*this, values);

  // If pad/stride/kernel are not constant, this is okay; we just can't check
  // their values. extractAdaptivePoolingConstShapeOperands will return an
  // empty list for each non-constant input. verifyPoolingOpImpl will need to
  // handle values not being present, and return success if they cannot be
  // checked.

  if (failed(verifyPoolingOpImpl(getOperation(), values.kernel, values.stride,
                                 values.pad, getInput(), getOutput())))
    return failure();

  if (failed(verifyAvgPoolCommonTypeAndZpChecks(*this)))
    return failure();

  return success();
}

LogicalResult tosa::ClampOp::verify() {
  mlir::Type inputETy =
      llvm::cast<ShapedType>(getInput().getType()).getElementType();
  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
    inputETy = getStorageElementTypeFromQuantized(quantType);
  }
  mlir::Type outputETy =
      llvm::cast<ShapedType>(getOutput().getType()).getElementType();
  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
    outputETy = getStorageElementTypeFromQuantized(quantType);
  }
  if (inputETy != outputETy)
    return emitOpError("input/output element types are incompatible.");

  auto maxValAttr = getMaxValAttr();
  auto minValAttr = getMinValAttr();

  unsigned dataTypeBitWidth = inputETy.getIntOrFloatBitWidth();

  if (inputETy.isInteger(dataTypeBitWidth)) {
    // If the input datatype is an integer, check that the min_val/max_val
    // attributes are integer attributes and that their type matches the
    // input's datatype.
    auto intMaxValAttr = mlir::dyn_cast<mlir::IntegerAttr>(maxValAttr);
    auto intMinValAttr = mlir::dyn_cast<mlir::IntegerAttr>(minValAttr);
    if (!intMaxValAttr || !intMinValAttr ||
        (intMaxValAttr.getType() != intMinValAttr.getType()) ||
        (intMaxValAttr.getType() != inputETy))
      return emitOpError("min/max attributes types are incompatible with "
                         "input/output element types.");

    const bool isUnsigned = inputETy.isUnsignedInteger();
    const bool isBoolean = inputETy.isInteger(1);
    const APInt minVal = intMinValAttr.getValue();
    const APInt maxVal = intMaxValAttr.getValue();
    if ((isUnsigned || isBoolean) ? maxVal.ult(minVal) : maxVal.slt(minVal))
      return emitOpError("expected min_val <= max_val, got min_val=")
             << minValAttr << ", max_val=" << maxValAttr;
  } else {
    // Otherwise, the input datatype is a float: check that the min_val/max_val
    // attributes share the same type and that their type matches the input's
    // datatype.
    auto floatMaxValAttr = mlir::dyn_cast<mlir::FloatAttr>(maxValAttr);
    auto floatMinValAttr = mlir::dyn_cast<mlir::FloatAttr>(minValAttr);
    if (!floatMaxValAttr || !floatMinValAttr ||
        (floatMaxValAttr.getType() != floatMinValAttr.getType()) ||
        (floatMaxValAttr.getType() != inputETy))
      return emitOpError("min/max attributes types are incompatible with "
                         "input/output element types.");

    const APFloat minVal = floatMinValAttr.getValue();
    const APFloat maxVal = floatMaxValAttr.getValue();
    if (minVal.isNaN() || maxVal.isNaN())
      return emitOpError("min/max attributes should not be 'NaN', got min_val=")
             << minValAttr << ", max_val=" << maxValAttr;

    if (maxVal < minVal)
      return emitOpError("expected min_val <= max_val, got min_val=")
             << minValAttr << ", max_val=" << maxValAttr;
  }

  return success();
}

//===----------------------------------------------------------------------===//
// TOSA Operator Quantization Builders.
//===----------------------------------------------------------------------===//

/// This builder is called on all convolution operators except TransposeConv,
/// which has specialized output shape semantics. The builder also defines the
/// bitwidth of the output given the bit width of the input & weight content.
static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                                     Type outputType, Value input, Value weight,
                                     Value bias, DenseI64ArrayAttr pad,
                                     DenseI64ArrayAttr stride,
                                     DenseI64ArrayAttr dilation,
                                     TypeAttr accType) {
  auto zps = createZPsAsConst(builder, input, weight);
  result.addOperands({input, weight, bias, zps.first, zps.second});
  result.addAttribute("pad", pad);
  result.addAttribute("stride", stride);
  result.addAttribute("dilation", dilation);
  result.addAttribute("acc_type", accType);
  Type finalOutputType = outputType;
  auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
  if (quantAttr) {
    finalOutputType =
        buildConvOpResultTypeInfo(builder, outputType, input, weight);
  }
  result.addTypes(finalOutputType);
}

/// Handles tosa.transpose_conv2d which has outpad and output shape
/// attributes.
static void
buildTransposeConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                                  Type outputType, Value input, Value weight,
                                  Value bias, DenseI64ArrayAttr outpad,
                                  DenseI64ArrayAttr stride, TypeAttr accType) {
  auto zps = createZPsAsConst(builder, input, weight);
  result.addOperands({input, weight, bias, zps.first, zps.second});
  result.addAttribute("out_pad", outpad);
  result.addAttribute("stride", stride);
  result.addAttribute("acc_type", accType);
  Type finalOutputType = outputType;
  auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
  if (quantAttr) {
    finalOutputType =
        buildConvOpResultTypeInfo(builder, outputType, input, weight);
  }
  result.addTypes(finalOutputType);
}

/// The tosa.matmul op is also intended to be generated where a
/// fully_connected op must be constructed where the weight is not a constant.
/// In this case, the fully_connected op must be expressed using matmul.
/// TODO: Add link to the legalization document explaining this.
static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
                                       OperationState &result, Type outputType,
                                       Value a, Value b) {
  auto zps = createZPsAsConst(builder, a, b);
  result.addOperands({a, b, zps.first, zps.second});

  Type finalOutputType{outputType};
  if (auto quantAttr = buildMatMulOpQuantizationAttr(builder, a, b)) {
    auto eType = getStorageElementTypeOrSelf(a.getType());
    auto inputBits = eType.getIntOrFloatBitWidth();

    auto outputShapedType = llvm::dyn_cast<ShapedType>(outputType);
    assert(outputShapedType && "Output must be a shaped type");

    IntegerType accElementType;
    if (inputBits == 16)
      accElementType = builder.getIntegerType(48);
    else
      accElementType = builder.getI32Type();

    finalOutputType = outputShapedType.clone(accElementType);
  }
  result.addTypes(finalOutputType);
}
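
// For example, a quantized i8 x i8 matmul result is rebuilt with an i32
// accumulator element type, and an i16 x i16 matmul with i48.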

/// Both the tosa.avg_pool2d and unary ops use the same
/// UnaryOpQuantizationAttr, but the avg_pool operator has its own builder as
/// it has additional parameters not part of the unary ops.
static void
buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                              Type outputType, Value input,
                              DenseArrayAttr kernel, DenseArrayAttr stride,
                              DenseArrayAttr pad, TypeAttr accType) {
  const Location loc{result.location};
  int64_t inputZp{0};
  int64_t outputZp{0};

  if (auto quantAttr =
          buildUnaryOpQuantizationAttr(builder, input, outputType)) {
    inputZp = quantAttr.getInputZp();
    outputZp = quantAttr.getOutputZp();
  }
  const std::optional<Value> inputZpOp =
      createZeroPointTensor(builder, loc, input.getType(), inputZp);
  if (!inputZpOp) {
    (void)emitError(
        loc,
        "Failed to create input zero point tensor for quantized AVG_POOL2D op");
  }
  const std::optional<Value> outputZpOp =
      createZeroPointTensor(builder, loc, outputType, outputZp);
  if (!outputZpOp) {
    (void)emitError(loc, "Failed to create output zero point tensor for "
                         "quantized AVG_POOL2D op");
  }

  if (inputZpOp && outputZpOp) {
    result.addOperands({input, inputZpOp.value(), outputZpOp.value()});
  } else {
    // Failed to create one or more zero points above: just add input as an
    // operand. This will trigger an error in building the op because of the
    // missing zero points.
    result.addOperands({input});
  }
  result.addAttribute("kernel", kernel);
  result.addAttribute("stride", stride);
  result.addAttribute("pad", pad);
  result.addAttribute("acc_type", accType);
  result.types.push_back(outputType);
}

/// This builder mirrors the avg_pool2d quant-info handling and materializes
/// kernel/stride/pad as const_shape operands for avg_pool2d_adaptive.
static void buildAvgPool2dAdaptiveOpWithQuantInfo(
    OpBuilder &builder, OperationState &result, Type outputType, Value input,
    DenseI64ArrayAttr kernel, DenseI64ArrayAttr stride, DenseI64ArrayAttr pad,
    TypeAttr accType) {
  const Location loc{result.location};
  int64_t inputZp{0};
  int64_t outputZp{0};

  if (auto quantAttr =
          buildUnaryOpQuantizationAttr(builder, input, outputType)) {
    inputZp = quantAttr.getInputZp();
    outputZp = quantAttr.getOutputZp();
  }
  const std::optional<Value> inputZpOp =
      createZeroPointTensor(builder, loc, input.getType(), inputZp);
  if (!inputZpOp) {
    (void)emitError(loc,
                    "Failed to create input zero point tensor for quantized "
                    "AVG_POOL2D_ADAPTIVE op");
  }
  const std::optional<Value> outputZpOp =
      createZeroPointTensor(builder, loc, outputType, outputZp);
  if (!outputZpOp) {
    (void)emitError(loc, "Failed to create output zero point tensor for "
                         "quantized AVG_POOL2D_ADAPTIVE op");
  }

  if (inputZpOp && outputZpOp) {
    ImplicitLocOpBuilder b(loc, builder);
    Value kernelShape = getTosaConstShape(b, kernel.asArrayRef());
    Value strideShape = getTosaConstShape(b, stride.asArrayRef());
    Value padShape = getTosaConstShape(b, pad.asArrayRef());
    result.addOperands({input, inputZpOp.value(), outputZpOp.value(),
                        kernelShape, strideShape, padShape});
  } else {
    // Failed to create one or more zero points above: just add input as an
    // operand. This will trigger an error in building the op because of
    // missing operands.
    result.addOperands({input});
  }
  result.addAttribute("acc_type", accType);
  result.types.push_back(outputType);
}
1544
1545/// This builder is called on single-parameter negate operator
1546/// to construct input and output zero points based on their
1547/// types.
1548static void buildNegateOpWithQuantInfo(OpBuilder &builder,
1549 OperationState &result, Type outputType,
1550 Value input) {
1551 const Location loc{result.location};
1552 int64_t input1Zp{0};
1553 int64_t outputZp{0};
1554 auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
1555 if (quantAttr) {
1556 input1Zp = quantAttr.getInputZp();
1557 outputZp = quantAttr.getOutputZp();
1558 }
1559 const std::optional<Value> input1ZpOp =
1560 createZeroPointTensor(builder, loc, input.getType(), input1Zp);
1561 if (!input1ZpOp) {
1562 (void)emitError(
1563 loc, "Failed to create input1 zero point for quantized NEGATE op");
1564 }
1565
1566 const std::optional<Value> outputZpOp =
1567 createZeroPointTensor(builder, loc, input.getType(), outputZp);
1568 if (!outputZpOp) {
1569 (void)emitError(
1570 loc, "Failed to create output zero point for quantized NEGATE op");
1571 }
1572
1573 if (input1ZpOp && outputZpOp) {
1574 result.addOperands({input, input1ZpOp.value(), outputZpOp.value()});
1575 } else {
 1576 // Failed to create one or more zero points above: just add the input as
 1577 // operand. This will trigger an error when building the op because of
 1578 // the missing zero points.
1579 result.addOperands({input});
1580 }
1581
1582 result.types.push_back(outputType);
1583}
1584
1585/// This builder is called on the TOSA pad operator, which needs to create
1586/// its own OptionalAttr quantization_attr parameter to scale the padding
1587/// values correctly. The absence of pad_const is interpreted as zero padding.
1588static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
1589 Type outputType, Value input,
1590 Value paddings) {
1591 const Location loc{result.location};
1592 int32_t zp{0};
1593 const auto quantAttr = buildPadOpQuantizationAttr(builder, input);
1594 if (quantAttr) {
1595 zp = static_cast<int32_t>(quantAttr.getInputZp());
1596 }
1597 const auto padConstOp{createPadConstTensor(builder, loc, input, zp)};
1598 result.addOperands({input, paddings, padConstOp});
1599 result.types.push_back(outputType);
1600}
1601
1602static void buildVariableOp(OpBuilder &builder, OperationState &result,
1603 StringRef name, Type variableType,
1604 Attribute initialValue) {
1605 const Location loc{result.location};
1606 auto nameAttr = builder.getStringAttr(name);
1607
1608 auto shapedType = dyn_cast<ShapedType>(variableType);
1609 if (!shapedType) {
1610 (void)emitError(loc, "variable type must be a shaped type");
1611 return;
1612 }
1613 if (!shapedType.hasRank()) {
1614 (void)emitError(loc, "variable type must be a ranked type");
1615 return;
1616 }
1617
1618 auto elementType = shapedType.getElementType();
1619 auto elementTypeAttr = TypeAttr::get(elementType);
1620 ArrayRef<int64_t> shape = shapedType.getShape();
1621 auto varShapeAttr = builder.getIndexTensorAttr(convertFromMlirShape(shape));
1622
1623 result.addAttribute("sym_name", nameAttr);
1624 result.addAttribute("var_shape", varShapeAttr);
1625 result.addAttribute("type", elementTypeAttr);
1626 result.addAttribute("initial_value", initialValue);
1627}
1628
1629//===----------------------------------------------------------------------===//
1630// TOSA Operator Return Type Inference.
1631//===----------------------------------------------------------------------===//
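/// Resolve a pair of broadcast-compatible dimension sizes. A sketch of the
/// behavior (with ? denoting ShapedType::kDynamic):
///   resolveBroadcastDim(1, 5) -> 5        (size-1 dims broadcast)
///   resolveBroadcastDim(?, 4) -> 4        (prefer the static dimension)
///   resolveBroadcastDim(3, 5) -> failure  (static sizes conflict)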
1632static FailureOr<int64_t> resolveBroadcastDim(const int64_t dim1,
1633 const int64_t dim2) {
1634 if (dim1 == 1)
1635 return dim2;
1636 if (dim2 == 1)
1637 return dim1;
1638
1639 if (ShapedType::isStatic(dim1) && ShapedType::isStatic(dim2) && dim1 != dim2)
1640 return failure();
1641
1642 // Prefer static dimension over dynamic
1643 return ShapedType::isDynamic(dim1) ? dim2 : dim1;
1644}
1645
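/// Resolve the broadcast shape of all operands by right-aligning their shapes
/// and resolving each dimension pair; e.g. (shapes illustrative) [1, 3] and
/// [2, 1, 1] right-align to [1, 1, 3] vs [2, 1, 1] and resolve to [2, 1, 3].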
1646static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
1647 SmallVector<int64_t> &outShape) {
1648 int64_t outRank = 0;
1649 for (int i = 0, e = operands.size(); i != e; ++i) {
1650 auto shape = operands.getShape(i);
1651 if (!shape.hasRank()) {
1652 // TODO(jennik): Update function to have better case handling for
1653 // invalid operands and for ranked tensors.
1654 return failure();
1655 }
1656 outRank = std::max<int64_t>(outRank, shape.getRank());
1657 }
1658
1659 outShape.resize(outRank, 1);
1660
1661 for (int i = 0, e = operands.size(); i != e; ++i) {
1662 auto shape = operands.getShape(i);
1663 auto rankDiff = outShape.size() - shape.getRank();
1664
1665 for (size_t i = 0, e = shape.getRank(); i < e; ++i) {
1666 auto dim1 = outShape[i + rankDiff];
1667 auto dim2 = shape.getDimSize(i);
1668
1669 const FailureOr<int64_t> maybeResolvedDim =
1670 resolveBroadcastDim(dim1, dim2);
1671 if (failed(maybeResolvedDim))
1672 return failure();
1673 const int64_t resolvedDim = *maybeResolvedDim;
1674 outShape[i + rankDiff] = resolvedDim;
1675 }
1676 }
1677
1678 return success();
1679}
1680
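/// The inferred argmax shape is the input shape with the reduction axis
/// removed; e.g. (shapes illustrative) an input of shape [2, 3, 4] with
/// axis = 1 infers the output shape [2, 4].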
1681LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
1682 MLIRContext *context, ::std::optional<Location> location,
1683 ArgMaxOp::Adaptor adaptor,
1684 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1685 ShapeAdaptor inputShape(adaptor.getInput().getType());
1686 IntegerAttr axis = adaptor.getProperties().axis;
1687 int32_t axisVal = axis.getValue().getSExtValue();
1688
1689 if (!inputShape.hasRank()) {
1690 inferredReturnShapes.push_back(ShapedTypeComponents());
1691 return success();
1692 }
1693
1694 SmallVector<int64_t> outShape;
1695 outShape.reserve(inputShape.getRank() - 1);
1696 for (int i = 0, s = inputShape.getRank(); i < s; i++) {
1697 if (i == axisVal)
1698 continue;
1699 outShape.push_back(inputShape.getDimSize(i));
1700 }
1701
1702 inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
1703 return success();
1704}
1705
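/// tosa.rfft2d maps a real input of shape [N, H, W] to two outputs (real and
/// imaginary) of shape [N, H, W / 2 + 1]; e.g. (illustrative) W = 8 gives an
/// output width of 5.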
1706LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
1707 MLIRContext *context, ::std::optional<Location> location,
1708 RFFT2dOp::Adaptor adaptor,
1709 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1710 ShapeAdaptor inputShape(adaptor.getInputReal().getType());
1711
1712 if (!inputShape.hasRank())
1713 return failure();
1714
1715 llvm::SmallVector<int64_t> outputShape;
1716 outputShape.resize(3, ShapedType::kDynamic);
1717 outputShape[0] = inputShape.getDimSize(0);
1718 outputShape[1] = inputShape.getDimSize(1);
1719 int64_t inWidth = inputShape.getDimSize(2);
1720
1721 // Note that we can support this calculation symbolically
1722 // in the future e.g. [x, y, z] -> [x, y, z / 2 + 1]
1723 if (inWidth != ShapedType::kDynamic)
1724 outputShape[2] = inWidth / 2 + 1;
1725
1726 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1727 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1728
1729 return success();
1730}
1731
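/// A positive value is a power of two iff clearing its lowest set bit yields
/// zero; e.g. 8 & 7 == 0b1000 & 0b0111 == 0, while 6 & 5 == 0b110 & 0b101 ==
/// 0b100 != 0.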
1732static LogicalResult verifyDimIsPowerOfTwo(Operation *op, const int64_t dimSize,
1733 const llvm::StringRef dimName) {
1734 const bool isPowerOfTwo = (dimSize & (dimSize - 1)) == 0 && dimSize > 0;
1735 if (!isPowerOfTwo)
1736 return op->emitOpError("expected ")
1737 << dimName << " to be a power of two, got " << dimSize;
1738
1739 return success();
1740}
1741
1742LogicalResult tosa::RFFT2dOp::verify() {
1743 const auto outputTypes = getResultTypes();
1744 if (failed(verifyCompatibleShapes(outputTypes)))
1745 return emitOpError("expected output shapes to match, got ") << outputTypes;
1746
1747 const auto inputType =
1748 llvm::dyn_cast<RankedTensorType>(getInputReal().getType());
1749 if (!inputType)
1750 return success();
1751
1752 const int64_t height = inputType.getDimSize(1);
1753 if (ShapedType::isStatic(height) &&
1754 failed(verifyDimIsPowerOfTwo(*this, height, "height")))
1755 return failure();
1756
1757 const int64_t width = inputType.getDimSize(2);
1758 if (ShapedType::isStatic(width) &&
1759 failed(verifyDimIsPowerOfTwo(*this, width, "width")))
1760 return failure();
1761
1762 const auto outputType = llvm::dyn_cast<RankedTensorType>(outputTypes[0]);
1763 if (!outputType)
1764 return success();
1765
1766 // Batch and height input/output dimensions should match
1767 if (failed(verifyCompatibleShape(inputType.getShape().drop_back(),
1768 outputType.getShape().drop_back())))
1769 return emitOpError("expected batch and height dimensions of input/output "
1770 "to match, got input=")
1771 << inputType << " output=" << outputType;
1772
1773 // Output width dimension expected to be input_width / 2 + 1
1774 const int64_t outputWidth = outputType.getDimSize(2);
1775 if (ShapedType::isStatic(width) && ShapedType::isStatic(outputWidth) &&
1776 (outputWidth != (width / 2) + 1))
1777 return emitOpError(
1778 "expected output width to be equal to input_width / 2 + 1, got ")
1779 << outputWidth;
1780
1781 return success();
1782}
1783
1784LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
1785 MLIRContext *context, ::std::optional<Location> location,
1786 FFT2dOp::Adaptor adaptor,
1787 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1788 inferredReturnShapes.push_back(
1789 ShapedTypeComponents(ShapeAdaptor(adaptor.getInputReal().getType())));
1790 inferredReturnShapes.push_back(
1791 ShapedTypeComponents(ShapeAdaptor(adaptor.getInputImag().getType())));
1792 return success();
1793}
1794
1795LogicalResult tosa::FFT2dOp::verify() {
1796 const auto inputRealType =
1797 llvm::dyn_cast<RankedTensorType>(getInputReal().getType());
1798 const auto inputImagType =
1799 llvm::dyn_cast<RankedTensorType>(getInputImag().getType());
1800 if (!inputRealType || !inputImagType)
1801 return success();
1802
1803 const auto trySelectStaticDim = [](const int64_t a, const int64_t b) {
 1804 return ShapedType::isDynamic(a) ? b : a;
1805 };
1806
1807 const int64_t height = trySelectStaticDim(inputRealType.getDimSize(1),
1808 inputImagType.getDimSize(1));
1809 if (ShapedType::isStatic(height) &&
1810 failed(verifyDimIsPowerOfTwo(*this, height, "height")))
1811 return failure();
1812
1813 const int64_t width = trySelectStaticDim(inputRealType.getDimSize(2),
1814 inputImagType.getDimSize(2));
1815 if (ShapedType::isStatic(width) &&
1816 failed(verifyDimIsPowerOfTwo(*this, width, "width")))
1817 return failure();
1818
1819 return success();
1820}
1821
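/// Concatenation shape inference sketch (shapes illustrative): concatenating
/// tensor<2x3xf32> and tensor<2x5xf32> on axis = 1 yields the shape [2, 8];
/// if any input is dynamic on the axis, the axis dimension becomes dynamic.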
1822LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
1823 MLIRContext *context, ::std::optional<Location> location,
1824 ConcatOp::Adaptor adaptor,
1825 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1826 // Infer all dimension sizes by reducing based on inputs.
1827 const Properties &prop = adaptor.getProperties();
1828 int32_t axis = prop.axis.getValue().getSExtValue();
1829 llvm::SmallVector<int64_t> outputShape;
1830 bool hasRankedInput = false;
1831 for (auto operand : adaptor.getOperands()) {
1832 ShapeAdaptor operandShape(operand.getType());
1833 if (!operandShape.hasRank())
1834 continue;
1835
1836 // Copy the Operand's rank.
1837 if (!hasRankedInput)
1838 outputShape.resize(operandShape.getRank(), ShapedType::kDynamic);
1839
1840 // Copy shapes until the dim is non-dynamic.
1841 for (int i = 0, s = operandShape.getRank(); i < s; i++) {
1842 if (i == axis || operandShape.isDynamicDim(i))
1843 continue;
1844 if (outputShape[i] == ShapedType::kDynamic)
1845 outputShape[i] = operandShape.getDimSize(i);
1846 if (outputShape[i] != operandShape.getDimSize(i))
1847 return emitOptionalError(location,
1848 "Cannot concat tensors with different sizes"
1849 " on the non-axis dimension ",
1850 i);
1851 }
1852
1853 hasRankedInput = true;
1854 }
1855
1856 if (adaptor.getInput1().empty())
1857 return failure();
1858
1859 Type inputType =
1860 llvm::cast<TensorType>(adaptor.getInput1().getType()[0]).getElementType();
1861 if (!hasRankedInput) {
1862 inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
1863 return success();
1864 }
1865
1866 // Determine the dimension size along the concatenation axis.
1867 int64_t concatDimSize = 0;
1868 for (auto operand : adaptor.getOperands()) {
1869 ShapeAdaptor operandShape(operand.getType());
1870
1871 // We need to know the length of the concatenation axis of all inputs to
1872 // determine the dimension size of the output shape.
1873 if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) {
1874 concatDimSize = ShapedType::kDynamic;
1875 break;
1876 }
1877
1878 concatDimSize += operandShape.getDimSize(axis);
1879 }
1880
1881 outputShape[axis] = concatDimSize;
1882
1883 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
1884 return success();
1885}
1886
1887LogicalResult tosa::ConcatOp::verify() {
 1888 // Check that each input has the same element type as the output.
1889 auto outType = getOutput().getType();
1890 const Operation::operand_range inputList = getInput1();
1891
1892 // Check there is at least one input
1893 if (inputList.empty())
1894 return emitOpError("expect at least one input");
1895
1896 if (!llvm::all_of(inputList, [&](auto input) {
1897 return succeeded(verifySameElementTypes(
1898 *this, /* inType = */ input.getType(), outType));
1899 })) {
1900 return failure();
1901 }
1902
1903 const int32_t axis = getAxis();
1904 ShapeAdaptor firstRankedInputShape = nullptr;
1905 for (const auto &input : inputList) {
1906 const Type inputType = input.getType();
1907 ShapeAdaptor currShape(inputType);
1908 if (currShape.hasRank()) {
1909 firstRankedInputShape = currShape;
1910 // Check axis is in expected range
1911 if (axis < 0 || axis >= firstRankedInputShape.getRank())
 1912 return emitOpError("expect axis to be within range 0 <= axis < "
1913 "rank(input1[firstRankedTensorIdx]), got ")
1914 << axis;
1915 break;
1916 }
1917 }
1918
1919 const auto allOperandsHasRank = [](const Value input) {
1920 return ShapeAdaptor(input.getType()).hasRank();
1921 };
1922 if (llvm::all_of(inputList, allOperandsHasRank)) {
1923 const int64_t firstInputRank = firstRankedInputShape.getRank();
1924
1925 for (const auto &[index, input] : llvm::enumerate(inputList.drop_front())) {
1926 const ShapeAdaptor inputShape(input.getType());
1927 const int64_t inputRank = inputShape.getRank();
1928 const size_t operandNum = index + 1;
1929
1930 // Check that each operand has the same rank
1931 if (inputRank != firstInputRank)
1932 return emitOpError(
1933 "expect all operands to have the same rank, but got ")
1934 << firstInputRank << " vs " << inputRank << " on operands 0 and "
1935 << operandNum;
1936
1937 // Check non-axis dims match
1938 for (int i = 0; i < inputRank; i++) {
1939 const int64_t inputDim = inputShape.getDimSize(i);
1940 const int64_t firstInputDim = firstRankedInputShape.getDimSize(i);
1941 if (i == axis || firstRankedInputShape.isDynamicDim(i) ||
1942 inputShape.isDynamicDim(i))
1943 continue;
1944 if (inputDim != firstInputDim)
1945 return emitOpError("expect all operand shapes to have the same sizes "
1946 "on non-axis dimensions, but got ")
1947 << inputDim << " vs " << firstInputDim << " at index " << i
1948 << " on operands 0 and " << operandNum;
1949 }
1950 }
1951
1952 const ShapeAdaptor outputShape(outType);
1953 if (outputShape.hasRank() && outputShape.getRank() != firstInputRank)
1954 return emitOpError("expect output rank to match inputs rank, got ")
1955 << outputShape.getRank() << " vs " << firstInputRank;
1956
1957 // ERROR_IF(axis_sum != shape[axis]);
1958 int64_t axisSum = 0;
1959 for (const auto &input : inputList) {
1960 const ShapeAdaptor inputShape(input.getType());
1961 if (inputShape.isDynamicDim(axis)) {
1962 // make axisSum negative to indicate invalid value
1963 axisSum = -1;
1964 break;
1965 }
1966 axisSum += inputShape.getDimSize(axis);
1967 }
1968
1969 if (axisSum >= 0 && outputShape.hasRank() &&
1970 !outputShape.isDynamicDim(axis) &&
1971 axisSum != outputShape.getDimSize(axis))
1972 return emitOpError("requires sum of axis dimensions of input1 "
1973 "equal to output axis dimension, got ")
1974 << axisSum << " and " << outputShape.getDimSize(axis);
1975 }
1976
1977 return success();
1978}
1979
1980LogicalResult tosa::EqualOp::inferReturnTypeComponents(
1981 MLIRContext *context, ::std::optional<Location> location,
1982 ValueShapeRange operands, DictionaryAttr attributes, PropertyRef properties,
1983 RegionRange regions,
1984 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1985 auto elementType = IntegerType::get(context, /*width=*/1);
1986
 1987 llvm::SmallVector<int64_t> outShape;
 1988 if (resolveBroadcastShape(operands, outShape).failed()) {
1989 inferredReturnShapes.push_back(ShapedTypeComponents(elementType));
1990 return success();
1991 }
1992
1993 inferredReturnShapes.push_back(ShapedTypeComponents(outShape, elementType));
1994 return success();
1995}
1996
1997bool tosa::EqualOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1998 if (l.size() != r.size() || l.size() != 1)
1999 return false;
2000 return succeeded(verifyCompatibleShape(l[0], r[0]));
2001}
2002
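/// tosa.matmul multiplies batched matrices [N, H, C] x [N, C, W] -> [N, H, W];
/// the inference below fills each output dimension from whichever operand
/// provides it statically.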
2003LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
2004 MLIRContext *context, ::std::optional<Location> location,
2005 MatMulOp::Adaptor adaptor,
2006 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2007 ShapeAdaptor lhsShape(adaptor.getA().getType());
2008 ShapeAdaptor rhsShape(adaptor.getB().getType());
2009
2010 // All shapes are dynamic.
2011 SmallVector<int64_t> outShape;
2012 outShape.resize(3, ShapedType::kDynamic);
2013
2014 if (lhsShape.hasRank()) {
2015 outShape[0] = lhsShape.getDimSize(0);
2016 outShape[1] = lhsShape.getDimSize(1);
2017 }
2018
2019 if (rhsShape.hasRank()) {
2020 outShape[0] = outShape[0] == ShapedType::kDynamic ? rhsShape.getDimSize(0)
2021 : outShape[0];
2022 outShape[2] = rhsShape.getDimSize(2);
2023 }
2024
2025 inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
2026 return success();
2027}
2028
2029LogicalResult MatMulOp::verify() {
2030 const ShapeAdaptor aShape(getA().getType());
2031 const ShapeAdaptor bShape(getB().getType());
2032 const Type aElementType = aShape.getElementType();
2033 const Type bElementType = bShape.getElementType();
2034
2035 const auto aQuantizedEType =
2036 llvm::dyn_cast<quant::UniformQuantizedType>(aElementType);
2037 const auto bQuantizedEType =
2038 llvm::dyn_cast<quant::UniformQuantizedType>(bElementType);
2039
2040 if (aQuantizedEType || bQuantizedEType) {
2041 if (!aQuantizedEType || !bQuantizedEType) {
2042 return emitOpError("expect operands to be both quantized or both not "
2043 "quantized, got ")
2044 << aElementType << " and " << bElementType;
2045 }
2046 // both a and b have quantized element types
2047 auto aQuantWidth = aQuantizedEType.getStorageTypeIntegralWidth();
2048 auto bQuantWidth = bQuantizedEType.getStorageTypeIntegralWidth();
2049 if (aQuantWidth != bQuantWidth) {
2050 return emitOpError("expect quantized operands to have same widths, got ")
2051 << aQuantWidth << " and " << bQuantWidth;
2052 }
2053 }
2054
2055 // check a_zp and b_zp
2056 auto aEType = getStorageElementTypeOrSelf(aElementType);
2057 auto aZpEType = getStorageElementTypeOrSelf(getAZp().getType());
2058 if (aEType != aZpEType)
2059 return emitOpError("expect input a and a_zp have the same "
2060 "element type, got ")
2061 << aEType << " and " << aZpEType;
2062
2063 const Type bEType = getStorageElementTypeOrSelf(bElementType);
2064 const Type bZpEType = getStorageElementTypeOrSelf(getBZp().getType());
2065 if (bEType != bZpEType)
2066 return emitOpError("expect input b and b_zp have the same "
2067 "element type, got ")
2068 << bEType << " and " << bZpEType;
2069
2070 FailureOr<int64_t> maybeAZp = getAZeroPoint();
2071 if (succeeded(maybeAZp) && verifyAZeroPoint(*maybeAZp).failed())
2072 return failure();
2073
2074 FailureOr<int64_t> maybeBZp = getBZeroPoint();
2075 if (succeeded(maybeBZp) && verifyBZeroPoint(*maybeBZp).failed())
2076 return failure();
2077
2078 // Verify input/output shapes
2079 int64_t N = ShapedType::kDynamic;
2080 int64_t H = ShapedType::kDynamic;
2081 int64_t W = ShapedType::kDynamic;
2082 int64_t C = ShapedType::kDynamic;
2083
2084 if (aShape.hasRank()) {
2085 N = aShape.getDimSize(0);
2086 H = aShape.getDimSize(1);
2087 C = aShape.getDimSize(2);
2088 }
2089
2090 if (bShape.hasRank()) {
2091 if (failed(tryUpdateDimOrFailure(*this, N, bShape.getDimSize(0), "b",
2092 "batch")) ||
2093 failed(tryUpdateDimOrFailure(*this, C, bShape.getDimSize(1), "b",
2094 "channels")))
2095 return failure();
2096 W = bShape.getDimSize(2);
2097 }
2098
2099 const SmallVector<int64_t, 3> expectedOutputShape = {N, H, W};
2100 const auto outputType = cast<ShapedType>(getResult().getType());
2101 if (outputType.hasRank() &&
2102 failed(
2103 verifyCompatibleShape(outputType.getShape(), expectedOutputShape))) {
2104 InFlightDiagnostic opError = emitOpError("expected output shape ");
2105 printShapeToDiagnostic(opError, outputType.getShape());
2106 opError << " to be compatible with expected output shape ";
2107 printShapeToDiagnostic(opError, expectedOutputShape);
2108 return opError;
2109 }
2110
2111 return success();
2112}
2113
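/// Block-scaled matmul layout, as checked by the verifier below (shapes
/// illustrative): A_data is [N, H, C] with A_scale [N, H, C / block_size],
/// and B_data is [D, W, C] with B_scale [D, W, C / block_size]; e.g. C = 64
/// with block_size = 32 gives scale tensors whose last dimension is 2. A B
/// batch of D = 1 broadcasts over A's batch N.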
2114LogicalResult tosa::MatmulTBlockScaledOp::inferReturnTypeComponents(
2115 MLIRContext *context, ::std::optional<Location> location,
2116 MatmulTBlockScaledOp::Adaptor adaptor,
2117 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2118 SmallVector<int64_t, 3> outShape(3, ShapedType::kDynamic);
2119
2120 const auto aDataShape = cast<ShapedType>(adaptor.getAData().getType());
2121 if (aDataShape.hasRank()) {
2122 outShape[0] = aDataShape.getDimSize(0);
2123 outShape[1] = aDataShape.getDimSize(1);
2124 }
2125
2126 const auto aScaleShape = cast<ShapedType>(adaptor.getAScale().getType());
2127 if (aScaleShape.hasRank()) {
2128 outShape[0] = ShapedType::isDynamic(outShape[0]) ? aScaleShape.getDimSize(0)
2129 : outShape[0];
2130 outShape[1] = ShapedType::isDynamic(outShape[1]) ? aScaleShape.getDimSize(1)
2131 : outShape[1];
2132 }
2133
2134 // If B batch size is 1, it is broadcast across A's batch size
2135 const auto bDataShape = cast<ShapedType>(adaptor.getBData().getType());
2136 if (bDataShape.hasRank()) {
2137 const int64_t bDataBatchSize = bDataShape.getDimSize(0);
2138 if (bDataBatchSize != 1)
2139 outShape[0] =
2140 ShapedType::isDynamic(outShape[0]) ? bDataBatchSize : outShape[0];
2141 outShape[2] = bDataShape.getDimSize(1);
2142 }
2143
2144 const auto bScaleShape = cast<ShapedType>(adaptor.getBScale().getType());
2145 if (bScaleShape.hasRank()) {
2146 const int64_t bScaleBatchSize = bScaleShape.getDimSize(0);
2147 if (bScaleBatchSize != 1)
2148 outShape[0] =
2149 ShapedType::isDynamic(outShape[0]) ? bScaleBatchSize : outShape[0];
2150 outShape[2] = ShapedType::isDynamic(outShape[2]) ? bScaleShape.getDimSize(1)
2151 : outShape[2];
2152 }
2153
2154 inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
2155 return success();
2156}
2157
2158LogicalResult MatmulTBlockScaledOp::verify() {
2159 // Verify same input data types
2160 const Type aDataType = getAData().getType();
2161 const Type bDataType = getBData().getType();
2162 if (failed(verifySameElementTypes(*this, aDataType, bDataType, "A_data",
2163 "B_data")))
2164 return failure();
2165
2166 // Verify input shape compatibility
2167 int64_t N = ShapedType::kDynamic;
2168 int64_t D = ShapedType::kDynamic;
2169 int64_t H = ShapedType::kDynamic;
2170 int64_t W = ShapedType::kDynamic;
2171 int64_t C = ShapedType::kDynamic;
2172 int64_t multiplesOfC = ShapedType::kDynamic;
2173
2174 const ShapeAdaptor aDataShape = ShapeAdaptor(aDataType);
2175 if (aDataShape.hasRank()) {
2176 N = aDataShape.getDimSize(0);
2177 H = aDataShape.getDimSize(1);
2178 C = aDataShape.getDimSize(2);
2179 }
2180
2181 const ShapeAdaptor aScaleShape = ShapeAdaptor(getAScale().getType());
2182 if (aScaleShape.hasRank()) {
2183 if (failed(tryUpdateDimOrFailure(*this, N, aScaleShape.getDimSize(0),
2184 "a_scale", "batch")) ||
2185 failed(tryUpdateDimOrFailure(*this, H, aScaleShape.getDimSize(1),
2186 "a_scale", "height")))
2187 return failure();
2188 multiplesOfC = aScaleShape.getDimSize(2);
2189 }
2190
2191 const ShapeAdaptor bDataShape = ShapeAdaptor(bDataType);
2192 if (bDataShape.hasRank()) {
2193 if (failed(tryUpdateDimOrFailure(*this, D, bDataShape.getDimSize(0),
2194 "b_data", "batch")) ||
2195 failed(tryUpdateDimOrFailure(*this, C, bDataShape.getDimSize(2),
2196 "b_data", "channels")))
2197 return failure();
2198 W = bDataShape.getDimSize(1);
2199 }
2200
2201 const ShapeAdaptor bScaleShape = ShapeAdaptor(getBScale().getType());
2202 if (bScaleShape.hasRank()) {
2203 if (failed(tryUpdateDimOrFailure(*this, D, bScaleShape.getDimSize(0),
2204 "b_scale", "batch")) ||
2205 failed(tryUpdateDimOrFailure(*this, W, bScaleShape.getDimSize(1),
2206 "b_scale", "width")) ||
2207 failed(tryUpdateDimOrFailure(*this, multiplesOfC,
2208 bScaleShape.getDimSize(2), "b_scale",
2209 "C/block_size")))
2210 return failure();
2211 }
2212
2213 // Verify batch size is broadcast compatible
2214 if (ShapedType::isStatic(N) && ShapedType::isStatic(D) && N != D && D != 1)
2215 return emitOpError("expect B matrix batch size to be broadcast compatible "
2216 "with A, got D=")
2217 << D << " vs N=" << N;
2218
2219 // Verify C is a multiple of block size
2220 const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize());
2221 if (blockSize != BlockSizeAttr::getBlockSizeValue(BlockSize::BLOCK_SIZE_32))
2222 return emitOpError("expect block size to be 32, got ") << blockSize;
2223 if (ShapedType::isStatic(C) && C % blockSize != 0)
2224 return emitOpError("expect C to be a multiple of block size, got C=")
2225 << C << ", block_size=" << blockSize;
2226
2227 // Verify multiplesOfC is C / block size
2228 if (ShapedType::isStatic(C) && ShapedType::isStatic(multiplesOfC) &&
2229 multiplesOfC != C / blockSize)
2230 return emitOpError(
2231 "expect scale operands dimension 2 to equal C/block_size (")
2232 << C << "/" << blockSize << ")" << ", got " << multiplesOfC;
2233
2234 // Verify output shape
2235 N = ShapedType::isDynamic(N) ? D : N;
2236 const SmallVector<int64_t, 3> expectedOutputShape = {N, H, W};
2237 const auto outputType = cast<ShapedType>(getResult().getType());
2238 if (outputType.hasRank() &&
2239 failed(
2240 verifyCompatibleShape(outputType.getShape(), expectedOutputShape))) {
2241 InFlightDiagnostic opError = emitOpError("expected output shape ");
2242 printShapeToDiagnostic(opError, outputType.getShape());
2243 opError << " to be compatible with expected output shape ";
2244 printShapeToDiagnostic(opError, expectedOutputShape);
2245 return opError;
2246 }
2247
2248 return success();
2249}
2250
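/// Each padded output dimension below is computed as
///   output[i] = input[i] + pad_front[i] + pad_back[i];
/// e.g. (illustrative) an input of shape [2, 3] with padding [1, 1, 0, 2]
/// infers the output shape [4, 5].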
2251LogicalResult tosa::PadOp::inferReturnTypeComponents(
2252 MLIRContext *context, ::std::optional<Location> location,
2253 PadOp::Adaptor adaptor,
2254 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2255 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2256 auto paddingRank =
2257 cast<tosa::shapeType>(adaptor.getPadding().getType()).getRank();
2258 SmallVector<int64_t> outputShape;
2259
2260 // If the input rank is unknown, we can infer the output rank using the
2261 // padding shape's rank divided by 2.
2262 if (!inputShape.hasRank()) {
2263 outputShape.resize(paddingRank / 2, ShapedType::kDynamic);
2264 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2265 return success();
2266 }
2267
2268 SmallVector<int64_t> paddingValues;
2269 // If the paddings value is not a constant, all dimensions must be dynamic.
2270 if (!tosa::getConstShapeValues(adaptor.getPadding().getDefiningOp(),
2271 paddingValues)) {
2272 outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
2273 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2274 return success();
2275 }
2276
2277 outputShape.reserve(inputShape.getRank());
2278 for (int i = 0, s = inputShape.getRank(); i < s; i++) {
2279 if (inputShape.isDynamicDim(i)) {
2280 outputShape.push_back(ShapedType::kDynamic);
2281 continue;
2282 }
2283 auto padFront = paddingValues[i * 2];
2284 auto padBack = paddingValues[i * 2 + 1];
2285 if (padFront < 0 || padBack < 0) {
2286 // if either padding for dim i is -1, output dim is unknown
2287 outputShape.push_back(ShapedType::kDynamic);
2288 continue;
2289 }
2290
2291 outputShape.push_back(inputShape.getDimSize(i) + padFront + padBack);
2292 }
2293
2294 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2295 return success();
2296}
2297
2298LogicalResult tosa::PadOp::verify() {
2299 if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
2300 /* outType = */ getOutput().getType())
2301 .failed()) {
2302 return failure();
2303 }
2304
2305 if (auto padConst = getPadConst()) {
2306 if (verifySameElementTypes(*this, /* inType = */ padConst.getType(),
2307 /* outType = */ getOutput().getType())
2308 .failed()) {
2309 return failure();
2310 }
2311 }
2312
2313 RankedTensorType inputType =
2314 llvm::dyn_cast<RankedTensorType>(getInput1().getType());
2315 RankedTensorType outputType =
2316 llvm::dyn_cast<RankedTensorType>(getOutput().getType());
2317 if (!inputType || !outputType)
2318 return success();
2319
2320 if (failed(verifyRanksMatch(getOperation(), inputType, outputType, "input",
2321 "output")))
2322 return failure();
2323
2324 auto inputRank = inputType.getRank();
2325 DenseIntElementsAttr paddingAttr;
2326 if (!matchPattern(getPadding(), m_Constant(&paddingAttr)))
2327 return success();
2328
2329 auto paddingValues = paddingAttr.getValues<APInt>();
2330 if (paddingValues.size() != static_cast<size_t>(inputRank * 2))
2331 return emitOpError() << "padding tensor must have " << inputRank
2332 << " * 2 = " << inputRank * 2 << " elements, but got "
2333 << paddingValues.size();
2334
2335 auto inputShape = inputType.getShape();
2336 auto outputShape = outputType.getShape();
2337
2338 for (int64_t i = 0; i < inputRank; ++i) {
2339 int64_t padStart = paddingValues[i * 2].getSExtValue();
2340 int64_t padEnd = paddingValues[i * 2 + 1].getSExtValue();
2341
2342 if ((padStart < 0 && padStart != -1) || (padEnd < 0 && padEnd != -1)) {
2343 return emitOpError()
2344 << "invalid padding values at dimension " << i
2345 << ": values must be non-negative or -1 for dynamic padding, got ["
2346 << padStart << ", " << padEnd << "]";
2347 }
2348
2349 // Skip shape verification for dynamic input/output
2350 if (inputShape[i] == ShapedType::kDynamic ||
2351 outputShape[i] == ShapedType::kDynamic)
2352 continue;
2353
2354 if (outputShape[i] != inputShape[i] + padStart + padEnd) {
2355 return emitOpError() << "mismatch in output shape at dimension " << i
2356 << ": expected " << inputShape[i] << " + "
2357 << padStart << " + " << padEnd << " = "
2358 << (inputShape[i] + padStart + padEnd)
2359 << ", but got " << outputShape[i];
2360 }
2361 }
2362
2363 return success();
2364}
2365
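/// Slice shape inference sketch (values illustrative): for an input of shape
/// [10, 20] with start = [2, 4] and size = [5, -1], dimension 0 yields 5 and
/// dimension 1 yields 20 - 4 = 16, i.e. the output shape [5, 16].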
2366LogicalResult tosa::SliceOp::inferReturnTypeComponents(
2367 MLIRContext *context, ::std::optional<Location> location,
2368 SliceOp::Adaptor adaptor,
2369 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2370
2371 Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
 2372 SmallVector<int64_t> start;
 2373 SmallVector<int64_t> size;
 2374
2375 if (!tosa::getConstShapeValues(adaptor.getStart().getDefiningOp(), start) ||
2376 !tosa::getConstShapeValues(adaptor.getSize().getDefiningOp(), size)) {
2377 auto rank = cast<tosa::shapeType>(adaptor.getSize().getType()).getRank();
2378 SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
2379 inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
2380 return success();
2381 }
2382
2383 // if size[i] is -1, all remaining elements in dimension i are included
2384 // in the slice, similar to TF.
2385 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2386 // initialize outputShape to all unknown
2387 SmallVector<int64_t> outputShape(size.size(), ShapedType::kDynamic);
2388 if (inputShape.hasRank()) {
2389 for (size_t i = 0; i < size.size(); i++) {
2390 if (size[i] != 0 && size[i] >= -1 && start[i] >= 0 &&
2391 (ShapedType::isDynamic(inputShape.getDimSize(i)) ||
2392 start[i] < inputShape.getDimSize(i))) {
2393 // size[i] is not 0 and not < -1, and start[i] is in valid range
2394 if (ShapedType::isDynamic(inputShape.getDimSize(i))) {
2395 // input shape has unknown dim[i] - only valid if size[i] > 0
2396 if (size[i] > 0) {
2397 outputShape[i] = size[i];
2398 }
2399 } else {
2400 // input shape has known dim[i]
2401 if (size[i] == -1) {
2402 outputShape[i] = inputShape.getDimSize(i) - start[i];
2403 } else if (start[i] + size[i] <= inputShape.getDimSize(i)) {
2404 // start[i] + size[i] is within bound of input shape's dim[i]
2405 outputShape[i] = size[i];
2406 }
2407 }
2408 }
2409 }
2410 } else {
2411 outputShape = convertToMlirShape(size);
2412 }
2413 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2414 return success();
2415}
2416
2417LogicalResult tosa::SliceOp::verify() {
2418 const Value input = getInput1();
2419 const Value output = getOutput();
2420 if (verifySameElementTypes(*this, /* inType = */ input.getType(),
2421 /* outType = */ output.getType())
2422 .failed())
2423 return failure();
2424
2425 const Value start = getStart();
2426 const Value size = getSize();
2427 const ShapeAdaptor inputShape(input.getType());
2428 const ShapeAdaptor outputShape(output.getType());
2429
2430 if (inputShape.hasRank()) {
2431 const auto inputRank = inputShape.getRank();
2432 if (outputShape.hasRank() && inputRank != outputShape.getRank())
2433 return emitOpError(
2434 "expect input1 and output to have the same ranks, got ")
2435 << inputRank << " and " << outputShape.getRank();
2436
2437 const auto startShapeRank =
2438 llvm::cast<tosa::shapeType>(start.getType()).getRank();
2439 if (inputRank != startShapeRank)
2440 return emitOpError("length of start is not equal to rank of input shape");
2441
2442 const auto sizeShapeRank =
2443 llvm::cast<tosa::shapeType>(size.getType()).getRank();
2444 if (inputRank != sizeShapeRank)
2445 return emitOpError("length of size is not equal to rank of input shape");
2446 }
2447
2448 SmallVector<int64_t> startValues;
2449 tosa::getConstShapeValues(start.getDefiningOp(), startValues);
2450 if (startValues.size()) {
2451 if (llvm::any_of(startValues, [](const int64_t v) {
2452 return v < 0 && v != kInferableDimSize;
2453 }))
2454 return emitOpError("start values must be non-negative, got [")
2455 << startValues << "]";
2456 }
2457
2458 SmallVector<int64_t> sizeValues;
2459 if (!tosa::getConstShapeValues(size.getDefiningOp(), sizeValues))
2460 return success();
2461
2462 if (llvm::any_of(sizeValues, [](const int64_t v) {
2463 return v <= 0 && v != kInferableDimSize;
2464 }))
2465 return emitOpError("size values must be > 0, got [") << sizeValues << "]";
2466 if (outputShape.hasRank()) {
2467 SmallVector<int64_t> outputDims;
2468 outputShape.getDims(outputDims);
2469 const bool hasNoInferableDims = llvm::all_of(
2470 sizeValues, [](const int64_t v) { return v != kInferableDimSize; });
2471 if (hasNoInferableDims &&
2472 failed(verifyCompatibleShape(outputDims, sizeValues)))
2473 return emitOpError("expected output shape to match size values, got ")
2474 << output.getType() << " vs [" << sizeValues << "]";
2475 }
2476
2477 if (inputShape.hasRank() && startValues.size()) {
2478 SmallVector<int64_t> inputDims;
2479 inputShape.getDims(inputDims);
2480 for (const auto &[index, vals] :
2481 llvm::enumerate(llvm::zip_equal(startValues, sizeValues, inputDims))) {
2482 const auto &[start, size, inputDim] = vals;
2483 if (start == kInferableDimSize || size == kInferableDimSize ||
2484 ShapedType::isDynamic(inputDim))
2485 continue;
2486 if (start + size > inputDim)
2487 return emitOpError("start + size must be less than or equal to input "
2488 "dimension size, got start=")
2489 << start << ", size=" << size
2490 << " vs input dim size=" << inputDim << " at dimension "
2491 << index;
2492 }
2493 }
2494
2495 return success();
2496}
2497
2498LogicalResult tosa::MulOp::inferReturnTypeComponents(
2499 MLIRContext *context, ::std::optional<Location> location,
2500 ValueShapeRange operands, DictionaryAttr attributes, PropertyRef properties,
2501 RegionRange regions,
2502 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
 2503 // The mul op's output shape depends only on input1 and input2, not on shift.
 2504 ValueShapeRange twoInputs = operands.drop_back();
 2505 llvm::SmallVector<int64_t> outShape;
2506 if (resolveBroadcastShape(twoInputs, outShape).failed()) {
2507 inferredReturnShapes.push_back(ShapedTypeComponents());
2508 } else {
2509 inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
2510 }
2511 return success();
2512}
2513
2514LogicalResult tosa::MulOp::verify() {
2515 const Value output = getOutput();
2516 auto resElemType = getElementTypeOrSelf(output);
2517
2518 // Verify if the element type among operands and result match tosa
2519 // specification.
2520 if (auto resIntType = dyn_cast<IntegerType>(resElemType)) {
2521 IntegerType lhsIntType =
2522 dyn_cast<IntegerType>(getElementTypeOrSelf(getInput1()));
2523 IntegerType rhsIntType =
2524 dyn_cast<IntegerType>(getElementTypeOrSelf(getInput2()));
2525 if (!lhsIntType || !rhsIntType || lhsIntType != rhsIntType)
2526 return emitOpError("requires the same element type for all operands");
2527
2528 // Though the spec requires the element type of result to be i32, a more
2529 // relaxed way is provided at dialect level for easier cooperating with
2530 // other dialects.
2531 if (lhsIntType.getWidth() > resIntType.getWidth())
2532 return emitOpError("invalid data type size for operands or result");
2533
2534 } else {
 2535 // For other supported types, the spec requires the same element
 2536 // type for all operands (excluding the `shift` operand) and results.
2537 for (int i = 0; i < 2; ++i) {
2538 if (getElementTypeOrSelf(getOperand(i)) != resElemType)
2539 return emitOpError(
2540 "requires the same element type for all operands and results");
2541 }
2542
2543 // verify shift has value 0 for non-integer types
2544 ElementsAttr shiftElem;
2545 if (matchPattern(getShift(), m_Constant(&shiftElem))) {
2546 int32_t shift = shiftElem.getValues<IntegerAttr>()[0].getInt();
2547 if (shift != 0) {
2548 return emitOpError() << "require shift to be 0 for float type";
2549 }
2550 }
2551 }
2552
 2553 // Verify the op has the same rank for all main operands (this excludes
 2554 // extra operands such as mul's shift, which is the only difference from the
 2555 // built-in `SameOperandsAndResultRank` trait) and result types, if known.
2556 TypeRange operandTypes = getOperandTypes();
2557 ShapedType aType = cast<ShapedType>(operandTypes[0]);
2558 ShapedType bType = cast<ShapedType>(operandTypes[1]);
2559
2560 const bool aHasRank = aType.hasRank();
2561 const bool bHasRank = bType.hasRank();
2562
2563 bool hasExpectedOutputShape = false;
2564 SmallVector<int64_t> expectedOutputShape;
2565
2566 if (aHasRank && bHasRank) {
2567 const int64_t aRank = aType.getRank();
2568 const int64_t bRank = bType.getRank();
2569 if (aRank != bRank)
2570 return emitOpError("a and b operands don't have matching ranks, got ")
2571 << aRank << " and " << bRank;
2572
2573 // check for broadcast compatible shapes
 2574 if (!mlir::OpTrait::util::getBroadcastedShape(
 2575 aType.getShape(), bType.getShape(), expectedOutputShape))
2576 return emitOpError("a and b operands don't have broadcast-compatible "
2577 "shapes, got ")
2578 << aType << " and " << bType;
2579 hasExpectedOutputShape = true;
2580 }
2581
2582 ShapedType resultType = cast<ShapedType>(output.getType());
2583 if (!resultType.hasRank())
2584 return success();
2585
2586 const int64_t resultRank = resultType.getRank();
2587 if (aHasRank && resultRank != aType.getRank())
2588 return emitOpError("result type has different rank than a, got ")
2589 << resultRank << " vs " << aType.getRank();
2590 if (bHasRank && resultRank != bType.getRank())
2591 return emitOpError("result type has different rank than b, got ")
2592 << resultRank << " vs " << bType.getRank();
2593
2594 if (hasExpectedOutputShape &&
2595 failed(verifyOutputShapeCompatibleWithExpected(getOperation(), resultType,
2596 expectedOutputShape)))
2597 return failure();
2598
2599 return success();
2600}
2601
2602LogicalResult tosa::TableOp::inferReturnTypeComponents(
2603 MLIRContext *context, ::std::optional<Location> location,
2604 TableOp::Adaptor adaptor,
2605 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2606 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2607
2608 if (!inputShape.hasRank()) {
2609 inferredReturnShapes.push_back(ShapedTypeComponents());
2610 return success();
2611 }
2612
2613 inferredReturnShapes.resize(1);
2614 inputShape.getDims(inferredReturnShapes[0]);
2615 return success();
2616}
2617
2618LogicalResult tosa::TableOp::verify() {
2619 const TensorType inputType = getInput1().getType();
2620 const TensorType outputType = getOutput().getType();
2621
2622 if (!inputType.hasRank() || !outputType.hasRank())
2623 return success();
2624
2625 if (failed(verifyRanksMatch(getOperation(), inputType, outputType, "input",
2626 "result")))
2627 return failure();
2628
2629 auto inputDims = inputType.getShape();
2630 auto outputDims = outputType.getShape();
2631 for (auto it : llvm::enumerate(llvm::zip(inputDims, outputDims))) {
2632 int64_t dim = it.index();
2633 auto [inputDim, outputDim] = it.value();
2634 if (ShapedType::isStatic(outputDim) && outputDim != inputDim) {
2635 return emitOpError() << "dim(result, " << dim << ") = " << outputDim
2636 << " doesn't match dim(input, " << dim
2637 << ") = " << inputDim;
2638 }
2639 }
2640 return success();
2641}
2642
2643LogicalResult
2644tosa::TileOp::getConstantMultiples(SmallVector<int64_t> &multiples) {
2645 // Multiples must be constants.
2646 DenseIntElementsAttr multiplesAttr;
2647 if (!matchPattern(getMultiples(), m_Constant(&multiplesAttr)))
2648 return failure();
2649 multiples =
2650 llvm::map_to_vector(multiplesAttr.getValues<APInt>(),
2651 [](const APInt &val) { return val.getSExtValue(); });
2652 return success();
2653}
2654
2655LogicalResult tosa::TileOp::inferReturnTypeComponents(
2656 MLIRContext *context, ::std::optional<Location> location,
2657 TileOp::Adaptor adaptor,
2658 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2659 Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
2660 SmallVector<int64_t> multiples;
2661 if (!tosa::getConstShapeValues(adaptor.getMultiples().getDefiningOp(),
2662 multiples)) {
2663 auto rank =
2664 cast<tosa::shapeType>(adaptor.getMultiples().getType()).getRank();
2665 SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
2666 inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
2667 return success();
2668 }
2669 multiples = convertToMlirShape(multiples);
2670
2671 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2672 SmallVector<int64_t> outputShape;
2673 if (!inputShape.hasRank()) {
2674 outputShape.resize(multiples.size(), ShapedType::kDynamic);
2675 inferredReturnShapes.push_back(
2676 ShapedTypeComponents(outputShape, inputType));
2677 return success();
2678 }
2679 if (static_cast<size_t>(inputShape.getRank()) != multiples.size())
2680 return failure();
2681
 2682 // Any non-dynamic dimension can be multiplied to a known size.
2683 outputShape.reserve(multiples.size());
2684 for (int i = 0, s = inputShape.getRank(); i < s; i++) {
2685 if (multiples[i] == ShapedType::kDynamic) {
2686 outputShape.push_back(ShapedType::kDynamic);
2687 } else {
2688 int64_t dim = inputShape.getDimSize(i);
2689 if (dim != ShapedType::kDynamic)
2690 dim *= multiples[i];
2691 outputShape.push_back(dim);
2692 }
2693 }
2694
2695 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
2696 return success();
2697}
2698
2699LogicalResult tosa::TileOp::verify() {
2700 if (verifySameElementTypes(*this, /* intype = */ getInput1().getType(),
2701 /* outType = */ getOutput().getType())
2702 .failed()) {
2703 return failure();
2704 }
2705 ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
2706 ShapedType outputType = llvm::cast<ShapedType>(getType());
2707
2708 shapeType multiplesType =
2709 llvm::cast<tosa::shapeType>(getMultiples().getType());
2710
2711 auto multiplesRank = multiplesType.getRank();
2712
2713 if (inputType.hasRank()) {
2714 if (inputType.getRank() != multiplesRank)
2715 return emitOpError("expect 'multiples' to have rank ")
2716 << inputType.getRank() << " but got " << multiplesRank << ".";
2717 if (outputType.hasRank() &&
2718 failed(verifyRanksMatch(getOperation(), inputType, outputType, "input",
2719 "output")))
2720 return failure();
2721 } else if (outputType.hasRank() && outputType.getRank() != multiplesRank)
2722 return emitOpError("expect 'multiples' array to have length ")
2723 << outputType.getRank() << " but got " << multiplesRank << ".";
2724
2725 SmallVector<int64_t> multiples;
2726 if (getConstantMultiples(multiples).succeeded() &&
2727 llvm::any_of(multiples, [](int64_t v) { return v <= 0 && v != -1; }))
2728 return emitOpError(
2729 "expect element of 'multiples' to be positive integer or -1.");
2730
2731 return success();
2732}
2733
2734bool tosa::ReshapeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
2735 if (l.size() != r.size() || l.size() != 1)
2736 return false;
2737 return getElementTypeOrSelf(l[0]) == getElementTypeOrSelf(r[0]);
2738}
2739
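/// Reshape inference resolves at most one inferred dimension from the element
/// count; e.g. (illustrative) reshaping a static 24-element tensor to
/// [2, -1, 4] divides 24 by the static product 2 * 4 = 8 to infer [2, 3, 4].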
2740LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
2741 MLIRContext *context, ::std::optional<Location> location,
2742 ReshapeOp::Adaptor adaptor,
2743 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2744 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2745 Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
2746 llvm::SmallVector<int64_t> newShapeValue;
2747 if (!tosa::getConstShapeValues(adaptor.getShape().getDefiningOp(),
2748 newShapeValue)) {
2749 auto rank = cast<tosa::shapeType>(adaptor.getShape().getType()).getRank();
2750 SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
2751 inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
2752 return success();
2753 }
2754 newShapeValue = convertToMlirShape(newShapeValue);
2755
2756 // We cannot infer from the total number of elements so we must take the
2757 // shape attribute as exact.
2758 if (!inputShape.hasRank() || !inputShape.hasStaticShape()) {
2759 inferredReturnShapes.push_back(
2760 ShapedTypeComponents(newShapeValue, inputType));
2761 return success();
2762 }
2763
2764 // Determine the number of elements covered by the slice of all static
2765 // dimensions. This allows us to infer the length of the remaining dynamic
2766 // dimension.
2767 int64_t numElements = inputShape.getNumElements();
2768 int64_t staticMul = 1;
2769 for (auto val : newShapeValue) {
2770 if (ShapedType::isStatic(val)) {
2771 staticMul *= val;
2772 }
2773 }
2774
2775 // Determine the length of the dynamic dimension.
2776 for (auto &val : newShapeValue) {
2777 if (ShapedType::isDynamic(val))
2778 val = numElements / staticMul;
2779 }
2780
2781 inferredReturnShapes.push_back(
2782 ShapedTypeComponents(newShapeValue, inputType));
2783 return success();
2784}
2785
2786llvm::LogicalResult tosa::ReshapeOp::verify() {
2787 if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
2788 /* outType = */ getOutput().getType())
2789 .failed()) {
2790 return failure();
2791 }
2792 TensorType inputType = getInput1().getType();
2793
2794 SmallVector<int64_t> shapeValues;
2795 if (!tosa::getConstShapeValues(getShape().getDefiningOp(), shapeValues)) {
2796 // skip following checks if shape is not constant
2797 return mlir::success();
2798 }
2799
2800 int missingDims = llvm::count(shapeValues, kInferableDimSize);
2801 if (missingDims > 1)
 2802 return emitOpError() << "expected at most one target dimension to be "
 2803 "inferred, got " << missingDims;
 2804
2805 const auto outputType = dyn_cast<RankedTensorType>(getType());
2806 if (!outputType)
2807 return success();
2808
2809 if ((int64_t)shapeValues.size() != outputType.getRank())
2810 return emitOpError() << "new shape does not match result rank";
2811
2812 for (auto [newShapeDim, outputShapeDim] :
2813 zip(shapeValues, outputType.getShape())) {
2814 if (newShapeDim != kInferableDimSize &&
2815 newShapeDim != ShapedType::kDynamic &&
2816 outputShapeDim != ShapedType::kDynamic && newShapeDim != outputShapeDim)
2817 return emitOpError() << "new shape is inconsistent with result shape";
2818
2819 if (newShapeDim != ShapedType::kDynamic && newShapeDim < kInferableDimSize)
2820 return emitOpError() << "new shape has invalid tensor dimension size "
2821 << newShapeDim;
2822 }
2823
2824 if (inputType.hasStaticShape()) {
2825 int64_t inputElementsNum = inputType.getNumElements();
2826 if (outputType.hasStaticShape()) {
2827 int64_t outputElementsNum = outputType.getNumElements();
2828 if (inputElementsNum != outputElementsNum) {
2829 return emitOpError() << "cannot reshape " << inputElementsNum
2830 << " elements into " << outputElementsNum;
2831 }
2832 }
2833
2834 int64_t newShapeElementsNum =
2835 llvm::accumulate(shapeValues, int64_t(1), [](int64_t acc, int64_t dim) {
2836 return (dim > 0) ? acc * dim : acc;
2837 });
2838 bool isStaticNewShape =
2839 llvm::all_of(shapeValues, [](int64_t s) { return s > 0; });
2840 if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
2841 (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
2842 return emitOpError() << "cannot reshape " << inputElementsNum
2843 << " elements into " << newShapeElementsNum;
2844 }
2845 }
2846
2847 return mlir::success();
2848}
2849
2850bool tosa::ReshapeBlockScaledOp::isCompatibleReturnTypes(TypeRange l,
2851 TypeRange r) {
2852 if (l.size() != r.size() || l.size() < 1 || l.size() > 2)
2853 return false;
2854 bool ok = (getElementTypeOrSelf(l[0]) == getElementTypeOrSelf(r[0]));
2855 if (l.size() == 2)
2856 ok = ok && (getElementTypeOrSelf(l[1]) == getElementTypeOrSelf(r[1]));
2857 return ok;
2858}
2859
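/// For the two-input (data + scale) form, the inferred scale result shape
/// tracks the new data shape with its last dimension divided by block_size;
/// e.g. (illustrative) a new shape of [4, 64] with block_size = 32 gives a
/// scale shape of [4, 2].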
2860LogicalResult tosa::ReshapeBlockScaledOp::inferReturnTypeComponents(
2861 MLIRContext *context, ::std::optional<Location> location,
2862 ReshapeBlockScaledOp::Adaptor adaptor,
2863 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2864
2865 const auto numInputs = adaptor.getInput().size();
2866 ShapeAdaptor inputShape(adaptor.getInput()[0].getType());
2867 Type inputType = getElementTypeOrSelf(adaptor.getInput()[0].getType());
2868 llvm::SmallVector<int64_t> newShapeValue;
2869 const auto newShape = adaptor.getNewValueShape();
2870 if (!tosa::getConstShapeValues(newShape.getDefiningOp(), newShapeValue)) {
2871 auto rank = cast<tosa::shapeType>(newShape.getType()).getRank();
2872 SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
2873 inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
2874 if (numInputs == 2)
2875 inferredReturnShapes.push_back(ShapedTypeComponents(
2876 fallback, getElementTypeOrSelf(adaptor.getInput()[1].getType())));
2877 return success();
2878 }
2879
2880 const uint32_t blockSize =
2881 BlockSizeAttr::getBlockSizeValue(adaptor.getBlockSize());
2882
2883 llvm::SmallVector<int64_t> newScaleShapeValue;
2884 if (numInputs == 2) {
2885 newScaleShapeValue.assign(newShapeValue.begin(), newShapeValue.end());
2886 if (ShapedType::isStatic(newScaleShapeValue.back()))
2887 newScaleShapeValue.back() /= blockSize;
2888 }
2889
2890 inferredReturnShapes.push_back(
2891 ShapedTypeComponents(newShapeValue, inputType));
2892 if (numInputs == 2) {
2893 // Fix up scale shape - with special case for last dimension
2894 for (size_t idx = 0; idx < newShapeValue.size(); idx++) {
2895 if (ShapedType::isDynamic(newScaleShapeValue[idx])) {
2896 newScaleShapeValue[idx] = newShapeValue[idx];
2897 if (idx == (newShapeValue.size() - 1))
2898 newScaleShapeValue[idx] /= blockSize;
2899 }
2900 }
2901
2902 inferredReturnShapes.push_back(ShapedTypeComponents(
2903 newScaleShapeValue,
2904 getElementTypeOrSelf(adaptor.getInput()[1].getType())));
2905 }
2906 return success();
2907}
2908
2909llvm::LogicalResult tosa::ReshapeBlockScaledOp::verify() {
2910 const Operation::operand_range inputList = getInput();
2911 const Operation::result_range outputList = getResults();
2912
2913 if (inputList.size() == 0)
2914 return emitOpError("requires at least one input");
2915
2916 if (inputList.size() > 2)
2917 return emitOpError("requires at most two inputs");
2918
2919 if (inputList.size() != outputList.size())
2920 return emitOpError("requires number of results to match inputs");
2921
2922 if (verifySameElementTypes(*this, /* inType = */ inputList[0].getType(),
2923 /* outType = */ outputList[0].getType())
2924 .failed()) {
2925 return failure();
2926 }
2927
2928 const auto inputType = llvm::cast<ShapedType>(inputList[0].getType());
2929 if (!inputType.hasRank())
2930 return success();
2931 const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize());
2932
2933 if (inputList.size() == 2) {
2934 if (blockSize != BlockSizeAttr::getBlockSizeValue(BlockSize::BLOCK_SIZE_32))
2935 return emitOpError("expect block size to be 32, got ") << blockSize;
2936 if (llvm::any_of(inputList, [](Value v) {
2937 const auto input = cast<ShapedType>(v.getType());
2938 return input.hasRank() && input.getRank() == 0;
2939 }))
2940 return emitOpError(
2941 "requires all input shapes have a rank greater than 0");
2942 if (llvm::any_of(outputList, [](Value v) {
2943 const auto output = cast<ShapedType>(v.getType());
2944 return output.hasRank() && output.getRank() == 0;
2945 }))
2946 return emitOpError(
2947 "requires all result shapes have a rank greater than 0");
2948
2949 if (verifySameElementTypes(*this, /* inType = */ inputList[1].getType(),
2950 /* outType = */ outputList[1].getType())
2951 .failed()) {
2952 return failure();
2953 }
2954
2955 const auto inputScaleType = llvm::cast<ShapedType>(inputList[1].getType());
2956 if (inputScaleType.hasRank()) {
2957 if (inputType.getRank() != inputScaleType.getRank())
2958 return emitOpError("input shapes do not have same rank");
2959
 2960 // Check that the input shape dimensions match on all but the last dimension
2961 for (auto dimIdx = 0; dimIdx < inputType.getRank() - 1; dimIdx++) {
2962 const int64_t inputValueDim = inputType.getDimSize(dimIdx);
2963 const int64_t inputScaleDim = inputScaleType.getShape()[dimIdx];
2964 if (ShapedType::isStatic(inputValueDim) &&
2965 ShapedType::isStatic(inputScaleDim) &&
2966 inputValueDim != inputScaleDim)
2967 return emitOpError("input shapes for data and scale do not match on "
2968 "dimension ")
2969 << dimIdx;
2970 }
2971
2972 // Verify last dimension of input is a multiple of block size
2973 const int64_t lastValueDim =
2974 inputType.getDimSize(inputType.getRank() - 1);
2975 if (ShapedType::isStatic(lastValueDim)) {
2976 if (lastValueDim % blockSize != 0)
2977 return emitOpError("expect last dimension of input_data (")
2978 << lastValueDim << ") to be divisible by block_size ("
2979 << blockSize << ")";
2980
2981 const int64_t lastScaleDim =
2982 inputScaleType.getDimSize(inputScaleType.getRank() - 1);
2983 // Verify last dimension of scale is lastValueDim / block size
2984 if (ShapedType::isStatic(lastScaleDim) &&
2985 lastScaleDim != lastValueDim / blockSize)
2986 return emitOpError("expect last dimension of scale_data (")
2987 << lastScaleDim << ") to be " << lastValueDim << "/"
2988 << blockSize;
2989 }
2990 }
2991 } else {
2992 if (blockSize != BlockSizeAttr::getBlockSizeValue(BlockSize::BLOCK_SIZE_1))
2993 return emitOpError("expect block size to be 1, got ") << blockSize;
2994 }
2995
2996 // Get the new value shape dimension values
2997 SmallVector<int64_t> shapeValues;
2998 if (!tosa::getConstShapeValues(getNewValueShape().getDefiningOp(),
2999 shapeValues)) {
3000 // skip following checks if shape is not constant
3001 return mlir::success();
3002 }
3003
3004 if (inputList.size() == 2) {
 3005 if (shapeValues.empty())
3006 return emitOpError("requires new shape to have a rank greater than 0");
3007
3008 const int64_t lastShapeDim = shapeValues.back();
3009 if (ShapedType::isStatic(lastShapeDim) && lastShapeDim % blockSize != 0)
3010 return emitOpError("expect last dimension of new shape (")
3011 << lastShapeDim << ") to be divisible by block_size (" << blockSize
3012 << ")";
3013 }
3014
3015 const auto outputType = llvm::cast<ShapedType>(outputList[0].getType());
3016 if (!outputType.hasRank())
3017 return success();
3018
3019 if (static_cast<int64_t>(shapeValues.size()) != outputType.getRank())
3020 return emitOpError() << "result does not match new shape rank";
3021
3022 for (auto [newShapeDim, outputShapeDim] :
3023 zip(shapeValues, outputType.getShape())) {
3024 if (ShapedType::isStatic(newShapeDim) &&
3025 ShapedType::isStatic(outputShapeDim) && newShapeDim != outputShapeDim)
3026 return emitOpError() << "result shape is inconsistent with new shape";
3027 }
3028
3029 if (outputList.size() == 2) {
 3030 // Set up the scale shape from the given new shape
3031 SmallVector<int64_t> scaleShapeValues(shapeValues.begin(),
3032 shapeValues.end());
3033 scaleShapeValues.back() /= blockSize;
3034
3035 const auto outputScaleType =
3036 llvm::cast<ShapedType>(outputList[1].getType());
3037 if (outputScaleType.hasRank()) {
3038 if ((int64_t)scaleShapeValues.size() != outputScaleType.getRank())
3039 return emitOpError() << "result scale does not match new shape rank";
3040
3041 for (auto [newScaleShapeDim, outputScaleShapeDim] :
3042 zip(scaleShapeValues, outputScaleType.getShape())) {
3043 if (ShapedType::isStatic(newScaleShapeDim) &&
3044 ShapedType::isStatic(outputScaleShapeDim) &&
3045 newScaleShapeDim != outputScaleShapeDim)
3046 return emitOpError()
3047 << "result scale shape is inconsistent with new shape";
3048 }
3049 }
3050 }
3051
3052 if (inputType.hasStaticShape()) {
3053 int64_t inputElementsNum = inputType.getNumElements();
3054 if (outputType.hasStaticShape()) {
3055 int64_t outputElementsNum = outputType.getNumElements();
3056 if (inputElementsNum != outputElementsNum) {
3057 return emitOpError() << "cannot reshape " << inputElementsNum
3058 << " elements into " << outputElementsNum;
3059 }
3060 }
3061
3062 int64_t newShapeElementsNum =
3063 llvm::accumulate(shapeValues, int64_t(1), [](int64_t acc, int64_t dim) {
3064 return (dim > 0) ? acc * dim : acc;
3065 });
3066 bool isStaticNewShape =
3067 llvm::all_of(shapeValues, [](int64_t s) { return s > 0; });
3068 if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
3069 (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
3070 return emitOpError() << "cannot reshape " << inputElementsNum
3071 << " elements into " << newShapeElementsNum;
3072 }
3073 }
3074
3075 return mlir::success();
3076}
3077
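// A worked example of the block-scale checks above (values chosen here for
// illustration, not taken from the spec): with block_size = 32, an input_data
// of shape 2x4x64 pairs with an input_scale of shape 2x4x2, since the scale's
// last dimension must equal 64/32. A new shape of [2, 8, 32] is then accepted
// (same element count, last dimension divisible by 32), and the result scale
// shape becomes 2x8x1.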
3078// Return failure if val is not a constant.
3079// Set zp to -1 if val is a non-zero float or is neither integer nor float;
3080// otherwise set zp to val's constant value.
3081static FailureOr<int64_t> getZeroPoint(Value val, bool signExtend) {
3082 ElementsAttr zpAttr;
3083 if (!matchPattern(val, m_Constant(&zpAttr))) {
3084 return failure();
3085 }
3086
3087 Type zpElemType = zpAttr.getElementType();
3088
3089 if (llvm::isa<FloatType>(zpElemType)) {
3090 if (zpAttr.getValues<APFloat>()[0].isZero()) {
3091 return 0;
3092 }
3093 // return non-zero value to trigger error check
3094 return -1;
3095 }
3096
3097 if (llvm::isa<IntegerType>(zpElemType)) {
3098 if (signExtend)
3099 return zpAttr.getValues<APInt>()[0].getSExtValue();
3100 return zpAttr.getValues<APInt>()[0].getZExtValue();
3101 }
3102
3103 // return non-zero value to trigger error check
3104 return -1;
3105}
3106
3107static FailureOr<int64_t> getConstantScalarIntValue(Value val) {
3108 ElementsAttr attr;
3109 if (!matchPattern(val, m_Constant(&attr)))
3110 return failure();
3111
3112 if (!llvm::isa<IntegerType>(attr.getElementType()) ||
3113 attr.getNumElements() != 1)
3114 return failure();
3115
3116 return attr.getValues<APInt>()[0].getSExtValue();
3117}
3118
3119template <typename T>
3120static LogicalResult verifyZeroPoint(T op, Value val, const int64_t &zp,
3121 const std::string &operand) {
3122 Type zpElemType = getElementTypeOrSelf(val);
3123
3124 if (!zpElemType.isInteger(8) && zp != 0) {
3125 // convert operand to lower case for error message
3126 std::string lower = operand;
3127 llvm::transform(lower, lower.begin(), ::tolower);
3128 return op.emitOpError()
3129 << lower << " zero point must be zero for non-int8 integer types";
3130 }
3131
3132 return success();
3133}
3134
3135static LogicalResult verifyZeroPoint(tosa::RescaleOp op, Value zpVal,
3136 const int64_t &zp,
3137 const std::string &operand) {
3138 bool isInputZp = (operand == "Input");
3139
3140 bool tensorUnsigned =
3141 isInputZp ? op.getInputUnsigned() : op.getOutputUnsigned();
3142 StringRef tensorName = isInputZp ? "input" : "output";
3143
3144 Type zpElemType = getElementTypeOrSelf(zpVal);
3145
3146 if (zp != 0) {
3147 if (!zpElemType.isInteger(8) &&
3148 !(zpElemType.isInteger(16) && tensorUnsigned)) {
3149 return op.emitOpError()
3150 << "expect " << tensorName << "_zp of 0, got " << zp;
3151 }
3152 if (zpElemType.isInteger(16) && tensorUnsigned && zp != 32768) {
3153 return op.emitOpError() << "expect " << tensorName
3154 << "_zp of 0 or 32768 for unsigned int16 "
3155 << tensorName << ", got " << zp;
3156 }
3157 }
3158
3159 return success();
3160}
3161
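// Summary of the rescale zero-point rules enforced above: an i8 tensor may
// use any zero point, an unsigned i16 tensor may use only 0 or 32768, and
// every other integer type must use a zero point of 0.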
3162#define ZERO_POINT_HELPER(OP, OPERAND_NAME, SIGN_EXTEND) \
3163 FailureOr<int64_t> tosa::OP::get##OPERAND_NAME##ZeroPoint() { \
3164 return getZeroPoint(get##OPERAND_NAME##Zp(), SIGN_EXTEND); \
3165 } \
3166 LogicalResult tosa::OP::verify##OPERAND_NAME##ZeroPoint(int64_t zp) { \
3167 return verifyZeroPoint(*this, get##OPERAND_NAME##Zp(), zp, #OPERAND_NAME); \
3168 }
3169
3170ZERO_POINT_HELPER(Conv2DOp, Input, true)
3171ZERO_POINT_HELPER(Conv2DOp, Weight, true)
3172ZERO_POINT_HELPER(Conv3DOp, Input, true)
3173ZERO_POINT_HELPER(Conv3DOp, Weight, true)
3174ZERO_POINT_HELPER(DepthwiseConv2DOp, Input, true)
3175ZERO_POINT_HELPER(DepthwiseConv2DOp, Weight, true)
3176ZERO_POINT_HELPER(TransposeConv2DOp, Input, true)
3177ZERO_POINT_HELPER(TransposeConv2DOp, Weight, true)
3178ZERO_POINT_HELPER(AvgPool2dOp, Input, true)
3179ZERO_POINT_HELPER(AvgPool2dOp, Output, true)
3180ZERO_POINT_HELPER(AvgPool2dAdaptiveOp, Input, true)
3181ZERO_POINT_HELPER(AvgPool2dAdaptiveOp, Output, true)
3182ZERO_POINT_HELPER(MatMulOp, A, true)
3183ZERO_POINT_HELPER(MatMulOp, B, true)
3184ZERO_POINT_HELPER(NegateOp, Input1, true)
3185ZERO_POINT_HELPER(NegateOp, Output, true)
3186ZERO_POINT_HELPER(RescaleOp, Input, !getInputUnsigned())
3187ZERO_POINT_HELPER(RescaleOp, Output, !getOutputUnsigned())
3188#undef ZERO_POINT_HELPER
3189
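// For reference, ZERO_POINT_HELPER(MatMulOp, A, true) expands to roughly the
// following (a sketch, modulo exact formatting):
//
//   FailureOr<int64_t> tosa::MatMulOp::getAZeroPoint() {
//     return getZeroPoint(getAZp(), /*signExtend=*/true);
//   }
//   LogicalResult tosa::MatMulOp::verifyAZeroPoint(int64_t zp) {
//     return verifyZeroPoint(*this, getAZp(), zp, "A");
//   }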
3190LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
3191 MLIRContext *context, ::std::optional<Location> location,
3192 TransposeOp::Adaptor adaptor,
3193 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3194 ShapeAdaptor inputShape(adaptor.getInput1().getType());
3195
3196 // If the input rank is unknown, the output rank cannot be inferred
3197 // either.
3198 if (!inputShape.hasRank()) {
3199 inferredReturnShapes.push_back(ShapedTypeComponents());
3200 return success();
3201 }
3202
3203 const auto inputRank = inputShape.getRank();
3204
3205 // This would imply the number of permutations does not match the rank of
3206 // the input, which is illegal.
3207 if (adaptor.getPerms().size() != static_cast<size_t>(inputRank)) {
3208 return failure();
3209 }
3210
3211 SmallVector<int64_t> outputShape;
3212 // Rank-0 means no permutations matter.
3213 if (inputRank == 0) {
3214 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3215 return success();
3216 }
3217
3218 // Check whether the input dimensions are all the same.
3219 bool allTheSame = true;
3220 for (int i = 1, s = inputRank; i < s; i++) {
3221 if (inputShape.getDimSize(0) != inputShape.getDimSize(i)) {
3222 allTheSame = false;
3223 break;
3224 }
3225 }
3226
3227 // If all of the input dimensions are the same we don't care about the
3228 // permutation.
3229 if (allTheSame) {
3230 outputShape.resize(inputRank, inputShape.getDimSize(0));
3231 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3232 return success();
3233 }
3234
3235 outputShape.resize(inputRank, ShapedType::kDynamic);
3236
3237 // Constant permutation values must be within the input rank.
3238 if (llvm::any_of(adaptor.getPerms(),
3239 [inputRank](const auto i) { return i >= inputRank; }))
3240 return failure();
3241
3242 // outputShape was already sized above; overwrite each entry in place.
3243 for (int i = 0, s = inputRank; i < s; i++) {
3244 outputShape[i] = inputShape.getDimSize(adaptor.getPerms()[i]);
3245 }
3246
3247 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3248 return success();
3249}
3250
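// Shape inference example for the logic above (illustrative): with
// perms = [2, 0, 1], an input of tensor<3x4x5> infers a tensor<5x3x4>
// result, since output dimension i takes the size of input dimension
// perms[i].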
3251LogicalResult tosa::TransposeOp::verify() {
3252 if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
3253 /* outType = */ getOutput().getType())
3254 .failed()) {
3255 return failure();
3256 }
3257
3258 const ShapeAdaptor inputShape(getInput1().getType());
3259 const ShapeAdaptor outputShape(getOutput().getType());
3260
3261 const llvm::ArrayRef<int32_t> constantPerms = getPerms();
3262
3263 if (inputShape.hasRank() &&
3264 constantPerms.size() != static_cast<size_t>(inputShape.getRank()))
3265 return emitOpError() << "expected perms attribute to have size "
3266 << inputShape.getRank()
3267 << " (input rank) but got size "
3268 << constantPerms.size();
3269
3270 if (inputShape.hasRank() && outputShape.hasRank() &&
3271 inputShape.getRank() != outputShape.getRank())
3272 return emitOpError()
3273 << "expected input tensor rank to equal result tensor rank";
3274
3275 if (outputShape.hasRank() &&
3276 constantPerms.size() != static_cast<size_t>(outputShape.getRank()))
3277 return emitOpError() << "expected perms attribute to have size "
3278 << outputShape.getRank()
3279 << " (output rank) but got size "
3280 << constantPerms.size();
3281
3282 if (!llvm::all_of(constantPerms,
3283 [&constantPerms](int32_t s) {
3284 return s >= 0 &&
3285 static_cast<size_t>(s) < constantPerms.size();
3286 }) ||
3287 !isPermutationVector(llvm::map_to_vector(
3288 constantPerms, [](int32_t v) -> int64_t { return v; })))
3289 return emitOpError() << "expected valid permutation indices";
3290
3291 // ERROR_IF(tensor_size(shape1) != tensor_size(shape))
3292 if (inputShape.hasStaticShape() && outputShape.hasStaticShape() &&
3293 inputShape.getNumElements() != outputShape.getNumElements())
3294 return emitOpError() << "expected input1 and output to have same numbers "
3295 "of elements, got "
3296 << inputShape.getNumElements() << " and "
3297 << outputShape.getNumElements();
3298
3299 // Verify that the types of the input and output tensors are properly
3300 // permuted.
3301 if (inputShape.hasRank() && outputShape.hasRank()) {
3302 for (auto i = 0; i < outputShape.getRank(); i++) {
3303 if (inputShape.isDynamicDim(constantPerms[i]) ||
3304 outputShape.isDynamicDim(i))
3305 continue;
3306
3307 if (inputShape.getDimSize(constantPerms[i]) != outputShape.getDimSize(i))
3308 return emitOpError()
3309 << "expected output tensor dim " << i << " to match "
3310 << "input dim " << constantPerms[i] << " with value of "
3311 << inputShape.getDimSize(constantPerms[i]);
3312 }
3313 }
3314
3315 return success();
3316}
3317
3318LogicalResult TransposeOp::reifyResultShapes(
3319 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
3320
3321 const llvm::ArrayRef<int32_t> transposePerms = getPerms();
3322
3323 Value input = getInput1();
3324 auto inputType = cast<TensorType>(input.getType());
3325
3326 SmallVector<OpFoldResult> returnedDims(inputType.getRank());
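// Note: transposePerms is a permutation (checked by the verifier), so
// iterating over its values visits every output dimension exactly once;
// output dimension 'dim' is reified from input dimension transposePerms[dim].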
3327 for (auto dim : transposePerms) {
3328 int32_t dimInInput = transposePerms[dim];
3329 if (inputType.isDynamicDim(dimInInput))
3330 returnedDims[dim] =
3331 tensor::DimOp::create(builder, getLoc(), input, dimInInput)
3332 .getResult();
3333 else
3334 returnedDims[dim] =
3335 builder.getIndexAttr(inputType.getDimSize(dimInInput));
3336 }
3337
3338 reifiedReturnShapes.emplace_back(std::move(returnedDims));
3339 return success();
3340}
3341
3342LogicalResult tosa::GatherOp::inferReturnTypeComponents(
3343 MLIRContext *context, ::std::optional<Location> location,
3344 GatherOp::Adaptor adaptor,
3345 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3346 llvm::SmallVector<int64_t> outputShape;
3347 outputShape.resize(3, ShapedType::kDynamic);
3348
3349 ShapeAdaptor valuesShape(adaptor.getValues().getType());
3350 if (valuesShape.hasRank()) {
3351 outputShape[0] = valuesShape.getDimSize(0);
3352 outputShape[2] = valuesShape.getDimSize(2);
3353 }
3354
3355 ShapeAdaptor indicesShape(adaptor.getIndices().getType());
3356 if (indicesShape.hasRank()) {
3357 if (outputShape[0] == ShapedType::kDynamic)
3358 outputShape[0] = indicesShape.getDimSize(0);
3359 if (outputShape[1] == ShapedType::kDynamic)
3360 outputShape[1] = indicesShape.getDimSize(1);
3361 }
3362
3363 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3364 return success();
3365}
3366
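// Gather inference example (illustrative): values tensor<2x?x5> with indices
// tensor<2x10> infers an output of tensor<2x10x5>; N and C come from the
// values operand and W from the indices operand.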
3367LogicalResult tosa::RowGatherBlockScaledOp::inferReturnTypeComponents(
3368 MLIRContext *context, ::std::optional<Location> location,
3369 RowGatherBlockScaledOp::Adaptor adaptor,
3370 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3371 const auto values = adaptor.getValues();
3372 if (values.empty())
3373 return failure();
3374
3375 SmallVector<int64_t> dataShape(3, ShapedType::kDynamic);
3376 const ShapeAdaptor valuesShape(values.front().getType());
3377 if (valuesShape.hasRank()) {
3378 dataShape[0] = valuesShape.getDimSize(0);
3379 dataShape[2] = valuesShape.getDimSize(2);
3380 }
3381
3382 const ShapeAdaptor indicesShape(adaptor.getIndices().getType());
3383 if (indicesShape.hasRank()) {
3384 if (dataShape[0] == ShapedType::kDynamic)
3385 dataShape[0] = indicesShape.getDimSize(0);
3386
3387 if (auto rowCount = getConstantScalarIntValue(adaptor.getRowCount());
3388 succeeded(rowCount) && rowCount.value() > 0) {
3389 const int64_t indicesW = indicesShape.getDimSize(1);
3390 if (ShapedType::isStatic(indicesW))
3391 dataShape[1] = indicesW * rowCount.value();
3392 }
3393 }
3394
3395 inferredReturnShapes.push_back(ShapedTypeComponents(dataShape));
3396 if (values.size() == 1)
3397 return success();
3398
3399 SmallVector<int64_t> scaleShape = dataShape;
3400 const uint32_t blockSize =
3401 BlockSizeAttr::getBlockSizeValue(adaptor.getBlockSize());
3402 if (ShapedType::isStatic(dataShape[2]))
3403 scaleShape[2] = dataShape[2] / blockSize;
3404
3405 inferredReturnShapes.push_back(ShapedTypeComponents(scaleShape));
3406 return success();
3407}
3408
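// RowGatherBlockScaled inference example (illustrative): values[0] of
// tensor<1x?x64>, indices of tensor<1x4>, row_count = 2 and block_size = 32
// infer a data result of tensor<1x8x64> (W * row_count rows) and a scale
// result of tensor<1x8x2> (channels divided by block_size).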
3409LogicalResult tosa::GatherOp::verify() {
3410 if (verifySameElementTypes(*this, /* inType = */ getValues().getType(),
3411 /* outType = */ getOutput().getType())
3412 .failed()) {
3413 return failure();
3414 }
3415
3416 const ShapeAdaptor valuesShape(getValues().getType());
3417 const ShapeAdaptor indicesShape(getIndices().getType());
3418 const ShapeAdaptor outputShape(getOutput().getType());
3419
3420 int64_t n = ShapedType::kDynamic;
3421 int64_t w = ShapedType::kDynamic;
3422 int64_t c = ShapedType::kDynamic;
3423
3424 if (valuesShape.hasRank()) {
3425 n = valuesShape.getDimSize(0);
3426 c = valuesShape.getDimSize(2);
3427 }
3428 if (indicesShape.hasRank()) {
3429 const int64_t indicesN = indicesShape.getDimSize(0);
3430 w = indicesShape.getDimSize(1);
3431 if (n == ShapedType::kDynamic)
3432 n = indicesN;
3433 else if (indicesN != ShapedType::kDynamic && n != indicesN)
3434 return emitOpError() << "requires indices dimension 0 to have size " << n
3435 << ", got " << indicesN;
3436 }
3437 if (outputShape.hasRank()) {
3438 const int64_t outputN = outputShape.getDimSize(0);
3439 const int64_t outputW = outputShape.getDimSize(1);
3440 const int64_t outputC = outputShape.getDimSize(2);
3441 if (n != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
3442 n != outputN)
3443 return emitOpError() << "requires output dimension 0 to have size " << n
3444 << ", got " << outputN;
3445
3446 if (w != ShapedType::kDynamic && outputW != ShapedType::kDynamic &&
3447 w != outputW)
3448 return emitOpError() << "requires output dimension 1 to have size " << w
3449 << ", got " << outputW;
3450 if (c != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
3451 c != outputC)
3452 return emitOpError() << "requires output dimension 2 to have size " << c
3453 << ", got " << outputC;
3454 }
3455 return success();
3456}
3457
3458LogicalResult tosa::RowGatherBlockScaledOp::verify() {
3459 const OperandRange values = getValues();
3460 const ResultRange output = getOutput();
3461 if (values.empty() || values.size() > 2)
3462 return emitOpError()
3463 << "expects values tensor list length to be 1 or 2, got "
3464 << values.size();
3465 if (output.size() != values.size())
3466 return emitOpError()
3467 << "expects output tensor list length to match values tensor list "
3468 "length, got "
3469 << output.size() << " results for " << values.size()
3470 << " input tensors";
3471
3472 const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize());
3473 if (values.size() == 1 && blockSize != 1)
3474 return emitOpError()
3475 << "requires block_size to be BLOCK_SIZE_1 when values tensor list "
3476 "length is 1";
3477 if (values.size() == 2 && blockSize == 1)
3478 return emitOpError()
3479 << "requires block_size to not be BLOCK_SIZE_1 when values tensor "
3480 "list length is 2";
3481
3482 if (failed(verifySameElementTypes(*this, values[0].getType(),
3483 output[0].getType(), "values[0]",
3484 "output[0]")))
3485 return failure();
3486 if (values.size() == 2 && failed(verifySameElementTypes(
3487 *this, values[1].getType(), output[1].getType(),
3488 "values[1]", "output[1]")))
3489 return failure();
3490
3491 if (auto rowCount = getConstantScalarIntValue(getRowCount());
3492 succeeded(rowCount) && rowCount.value() <= 0)
3493 return emitOpError() << "requires row_count to be > 0, got "
3494 << rowCount.value();
3495
3496 int64_t n = ShapedType::kDynamic;
3497 int64_t k = ShapedType::kDynamic;
3498 int64_t c = ShapedType::kDynamic;
3499 int64_t w = ShapedType::kDynamic;
3500 int64_t multiplesOfC = ShapedType::kDynamic;
3501
3502 const ShapeAdaptor valuesDataShape(values[0].getType());
3503 if (valuesDataShape.hasRank()) {
3504 n = valuesDataShape.getDimSize(0);
3505 k = valuesDataShape.getDimSize(1);
3506 c = valuesDataShape.getDimSize(2);
3507 }
3508
3509 if (ShapedType::isStatic(c) && c % blockSize != 0)
3510 return emitOpError() << "expects channels of values[0] (" << c
3511 << ") to be divisible by block_size (" << blockSize
3512 << ")";
3513
3514 const ShapeAdaptor indicesShape(getIndices().getType());
3515 if (indicesShape.hasRank()) {
3516 if (failed(tryUpdateDimOrFailure(*this, n, indicesShape.getDimSize(0),
3517 "indices", "batch")))
3518 return failure();
3519 w = indicesShape.getDimSize(1);
3520 }
3521
3522 const ShapeAdaptor outputDataShape(output[0].getType());
3523 if (outputDataShape.hasRank()) {
3524 if (failed(tryUpdateDimOrFailure(*this, n, outputDataShape.getDimSize(0),
3525 "output[0]", "batch")) ||
3526 failed(tryUpdateDimOrFailure(*this, c, outputDataShape.getDimSize(2),
3527 "output[0]", "channels")))
3528 return failure();
3529
3530 if (auto rowCount = getConstantScalarIntValue(getRowCount());
3531 succeeded(rowCount) && rowCount.value() > 0 &&
3532 ShapedType::isStatic(w)) {
3533 const int64_t expectedOutputRows = w * rowCount.value();
3534 if (ShapedType::isStatic(outputDataShape.getDimSize(1)) &&
3535 outputDataShape.getDimSize(1) != expectedOutputRows)
3536 return emitOpError() << "requires output[0] dimension 1 to have size "
3537 << expectedOutputRows << ", got "
3538 << outputDataShape.getDimSize(1);
3539 }
3540 }
3541
3542 if (values.size() == 2) {
3543 const ShapeAdaptor valuesScaleShape(values[1].getType());
3544 if (valuesScaleShape.hasRank()) {
3545 if (failed(tryUpdateDimOrFailure(*this, n, valuesScaleShape.getDimSize(0),
3546 "values[1]", "batch")) ||
3547 failed(tryUpdateDimOrFailure(*this, k, valuesScaleShape.getDimSize(1),
3548 "values[1]", "rows")))
3549 return failure();
3550 multiplesOfC = valuesScaleShape.getDimSize(2);
3551 }
3552
3553 const ShapeAdaptor outputScaleShape(output[1].getType());
3554 if (outputScaleShape.hasRank()) {
3555 if (failed(tryUpdateDimOrFailure(*this, n, outputScaleShape.getDimSize(0),
3556 "output[1]", "batch")))
3557 return failure();
3558
3559 if (auto rowCount = getConstantScalarIntValue(getRowCount());
3560 succeeded(rowCount) && rowCount.value() > 0 &&
3561 ShapedType::isStatic(w)) {
3562 const int64_t expectedOutputRows = w * rowCount.value();
3563 if (ShapedType::isStatic(outputScaleShape.getDimSize(1)) &&
3564 outputScaleShape.getDimSize(1) != expectedOutputRows)
3565 return emitOpError() << "requires output[1] dimension 1 to have size "
3566 << expectedOutputRows << ", got "
3567 << outputScaleShape.getDimSize(1);
3568 }
3569
3570 if (ShapedType::isDynamic(multiplesOfC))
3571 multiplesOfC = outputScaleShape.getDimSize(2);
3572 else if (ShapedType::isStatic(outputScaleShape.getDimSize(2)) &&
3573 multiplesOfC != outputScaleShape.getDimSize(2))
3574 return emitOpError()
3575 << "expected channels of output[1] to match size "
3576 << multiplesOfC << ", got " << outputScaleShape.getDimSize(2);
3577 }
3578
3579 if (ShapedType::isStatic(c) && ShapedType::isStatic(multiplesOfC) &&
3580 multiplesOfC != c / blockSize)
3581 return emitOpError()
3582 << "expects channels of scale tensors to equal C/block_size (" << c
3583 << "/" << blockSize << "), got " << multiplesOfC;
3584 }
3585
3586 return success();
3587}
3588
3589LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
3590 MLIRContext *context, ::std::optional<Location> location,
3591 ResizeOp::Adaptor adaptor,
3592 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3593 llvm::SmallVector<int64_t, 4> outputShape;
3594 outputShape.resize(4, ShapedType::kDynamic);
3595
3596 ShapeAdaptor inputShape(adaptor.getInput().getType());
3597 if (!inputShape.hasRank())
3598 return failure();
3599
3600 outputShape[0] = inputShape.getDimSize(0);
3601 outputShape[3] = inputShape.getDimSize(3);
3602 int64_t inputHeight = inputShape.getDimSize(1);
3603 int64_t inputWidth = inputShape.getDimSize(2);
3604
3605 if ((inputHeight == ShapedType::kDynamic) ||
3606 (inputWidth == ShapedType::kDynamic))
3607 return failure();
3608
3609 SmallVector<int64_t> scaleInt, offsetInt, borderInt;
3610 if (!tosa::getConstShapeValues(adaptor.getScale().getDefiningOp(),
3611 scaleInt) ||
3612 !tosa::getConstShapeValues(adaptor.getOffset().getDefiningOp(),
3613 offsetInt) ||
3614 !tosa::getConstShapeValues(adaptor.getBorder().getDefiningOp(),
3615 borderInt)) {
3616 return failure();
3617 }
3618
3619 // Compute the output shape based on attributes: scale, offset, and border.
3620 const int64_t outputHeight =
3621 (((inputHeight - 1) * scaleInt[0] - offsetInt[0] + borderInt[0]) /
3622 scaleInt[1]) +
3623 1;
3624
3625 const int64_t outputWidth =
3626 (((inputWidth - 1) * scaleInt[2] - offsetInt[1] + borderInt[1]) /
3627 scaleInt[3]) +
3628 1;
3629
3630 if (outputHeight < 0 || outputWidth < 0) {
3631 return emitOptionalError(
3632 location,
3633 "calculated output height and width must be non-negative, "
3634 "got height = ",
3635 outputHeight, ", width = ", outputWidth);
3636 }
3637
3638 outputShape[1] = outputHeight;
3639 outputShape[2] = outputWidth;
3640 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3641 return success();
3642}
3643
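// Worked example of the resize output-size formula above (values chosen for
// illustration): input_height = 8 with scale_y = 4/2, offset_y = 0 and
// border_y = 0 gives output_height = ((8 - 1) * 4 - 0 + 0) / 2 + 1 = 15.
// The verifier below additionally requires the division to be exact.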
3644LogicalResult tosa::ResizeOp::verify() {
3645 const Value input = getInput();
3646 const Value output = getOutput();
3647 const RankedTensorType inputType =
3648 llvm::dyn_cast<RankedTensorType>(input.getType());
3649 const RankedTensorType outputType =
3650 llvm::dyn_cast<RankedTensorType>(output.getType());
3651
3652 SmallVector<int64_t> scaleValues;
3653 SmallVector<int64_t> offsetValues;
3654 SmallVector<int64_t> borderValues;
3655 if (!tosa::getConstShapeValues(getScale().getDefiningOp(), scaleValues) ||
3656 !tosa::getConstShapeValues(getOffset().getDefiningOp(), offsetValues) ||
3657 !tosa::getConstShapeValues(getBorder().getDefiningOp(), borderValues)) {
3658 // Skip following checks if shape is not constant
3659 return success();
3660 }
3661
3662 if (llvm::any_of(scaleValues, [](int64_t s) { return s <= 0; }))
3663 return emitOpError("expect all scale values to be > 0, got ")
3664 << scaleValues;
3665
3666 const int64_t scaleYN = scaleValues[0];
3667 const int64_t scaleYD = scaleValues[1];
3668 const int64_t scaleXN = scaleValues[2];
3669 const int64_t scaleXD = scaleValues[3];
3670
3671 const int64_t offsetY = offsetValues[0];
3672 const int64_t offsetX = offsetValues[1];
3673
3674 const int64_t borderY = borderValues[0];
3675 const int64_t borderX = borderValues[1];
3676
3677 if (!inputType)
3678 return success();
3679 if (!outputType)
3680 return success();
3681
3682 const int64_t oh = outputType.getDimSize(1);
3683 const int64_t ow = outputType.getDimSize(2);
3684 const int64_t ih = inputType.getDimSize(1);
3685 const int64_t iw = inputType.getDimSize(2);
3686
3687 // Skip the check when the input height may be broadcast (ih == 1),
3688 // since Linalg, a consumer of TOSA, expects broadcasting support
3689 // in resize to be available. Taking the cautious approach for now,
3690 // we can consider removing support for broadcasting later.
3691 if (ih != ShapedType::kDynamic && ih != 1) {
3692 const std::optional<int64_t> calculatedOutHeightMinusOne =
3693 idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
3694 if (!calculatedOutHeightMinusOne.has_value())
3695 return emitOpError("expected (input_height - 1) * scale_y_n - offset_y + "
3696 "border_y ")
3697 << "to be wholly divisible by scale_y_d, got ((" << ih
3698 << " - 1) * " << scaleYN << " - " << offsetY << " + " << borderY
3699 << ") / " << scaleYD;
3700 const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
3701 if (oh != ShapedType::kDynamic && calculatedOutHeight != oh)
3702 return emitOpError("calculated output height did not match expected: ")
3703 << "calculated=" << calculatedOutHeight << ", expected=" << oh;
3704 }
3705
3706 // Skip the check when the input width may be broadcast (iw == 1),
3707 // since Linalg, a consumer of TOSA, expects broadcasting support
3708 // in resize to be available. Taking the cautious approach for now,
3709 // we can consider removing support for broadcasting later.
3710 if (iw != ShapedType::kDynamic && iw != 1) {
3711 const int64_t scaledInWidth = (iw - 1) * scaleXN - offsetX + borderX;
3712 const std::optional<int64_t> calculatedOutWidthMinusOne =
3713 idivCheck(scaledInWidth, scaleXD);
3714 if (!calculatedOutWidthMinusOne.has_value())
3715 return emitOpError("expected (input_width - 1) * scale_x_n - offset_x + "
3716 "border_x ")
3717 << "to be wholly divisible by scale_x_d, got ((" << iw
3718 << " - 1) * " << scaleXN << " - " << offsetX << " + " << borderX
3719 << ") / " << scaleXD;
3720 const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
3721 if (ow != ShapedType::kDynamic && calculatedOutWidth != ow)
3722 return emitOpError("calculated output width did not match expected: ")
3723 << "calculated=" << calculatedOutWidth << ", expected=" << ow;
3724 }
3725
3726 return success();
3727}
3728
3729LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
3730 MLIRContext *context, ::std::optional<Location> location,
3731 ScatterOp::Adaptor adaptor,
3732 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3733 llvm::SmallVector<int64_t> outputShape;
3734 outputShape.resize(3, ShapedType::kDynamic);
3735
3736 ShapeAdaptor valuesInShape(adaptor.getValuesIn().getType());
3737 if (valuesInShape.hasRank()) {
3738 outputShape[0] = valuesInShape.getDimSize(0);
3739 outputShape[1] = valuesInShape.getDimSize(1);
3740 outputShape[2] = valuesInShape.getDimSize(2);
3741 }
3742
3743 ShapeAdaptor indicesShape(adaptor.getIndices().getType());
3744 if (indicesShape.hasRank()) {
3745 if (outputShape[0] == ShapedType::kDynamic)
3746 outputShape[0] = indicesShape.getDimSize(0);
3747 }
3748
3749 ShapeAdaptor inputShape(adaptor.getInput().getType());
3750 if (inputShape.hasRank()) {
3751 if (outputShape[0] == ShapedType::kDynamic)
3752 outputShape[0] = inputShape.getDimSize(0);
3753 if (outputShape[2] == ShapedType::kDynamic)
3754 outputShape[2] = inputShape.getDimSize(2);
3755 }
3756
3757 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3758 return success();
3759}
3760
3761LogicalResult tosa::ScatterOp::verify() {
3762 if (verifySameElementTypes(*this, /* inType = */ getValuesIn().getType(),
3763 /* outType = */ getValuesOut().getType())
3764 .failed() ||
3765 verifySameElementTypes(*this, /* inType = */ getInput().getType(),
3766 /* outType = */ getValuesOut().getType())
3767 .failed()) {
3768 return failure();
3769 }
3770
3771 const ShapeAdaptor valuesInShape(getValuesIn().getType());
3772 const ShapeAdaptor indicesShape(getIndices().getType());
3773 const ShapeAdaptor inputShape(getInput().getType());
3774 const ShapeAdaptor outputShape(getValuesOut().getType());
3775
3776 int64_t n = ShapedType::kDynamic;
3777 int64_t k = ShapedType::kDynamic;
3778 int64_t w = ShapedType::kDynamic;
3779 int64_t c = ShapedType::kDynamic;
3780 if (valuesInShape.hasRank()) {
3781 n = valuesInShape.getDimSize(0);
3782 k = valuesInShape.getDimSize(1);
3783 c = valuesInShape.getDimSize(2);
3784 }
3785 if (indicesShape.hasRank()) {
3786 const int64_t indicesN = indicesShape.getDimSize(0);
3787 w = indicesShape.getDimSize(1);
3788 if (n == ShapedType::kDynamic)
3789 n = indicesN;
3790 else if (indicesN != ShapedType::kDynamic && n != indicesN)
3791 return emitOpError() << "requires indices dimension 0 to have size " << n
3792 << ", got " << indicesN;
3793 }
3794 if (inputShape.hasRank()) {
3795 const int64_t inputN = inputShape.getDimSize(0);
3796 const int64_t inputW = inputShape.getDimSize(1);
3797 const int64_t inputC = inputShape.getDimSize(2);
3798 if (n == ShapedType::kDynamic)
3799 n = inputN;
3800 else if (inputN != ShapedType::kDynamic && n != inputN)
3801 return emitOpError() << "requires input dimension 0 to have size " << n
3802 << ", got " << inputN;
3803 if (w == ShapedType::kDynamic)
3804 w = inputW;
3805 else if (inputW != ShapedType::kDynamic && w != inputW)
3806 return emitOpError() << "requires input dimension 1 to have size " << w
3807 << ", got " << inputW;
3808
3809 if (c == ShapedType::kDynamic)
3810 c = inputC;
3811 else if (inputC != ShapedType::kDynamic && c != inputC)
3812 return emitOpError() << "requires input dimension 2 to have size " << c
3813 << ", got " << inputC;
3814 }
3815 if (outputShape.hasRank()) {
3816 const int64_t outputN = outputShape.getDimSize(0);
3817 const int64_t outputK = outputShape.getDimSize(1);
3818 const int64_t outputC = outputShape.getDimSize(2);
3819 if (n != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
3820 n != outputN)
3821 return emitOpError() << "requires values_out dimension 0 to have size "
3822 << n << ", got " << outputN;
3823 if (k == ShapedType::kDynamic)
3824 k = outputK;
3825 else if (outputK != ShapedType::kDynamic && k != outputK)
3826 return emitOpError() << "requires values_out dimension 1 to have size "
3827 << k << ", got " << outputK;
3828 if (c != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
3829 c != outputC)
3830 return emitOpError() << "requires values_out dimension 2 to have size "
3831 << c << ", got " << outputC;
3832 }
3833 if (k != ShapedType::kDynamic && w != ShapedType::kDynamic && k < w)
3834 return emitOpError() << "requires dimensions K >= W, got K=" << k
3835 << " and W=" << w;
3836
3837 return success();
3838}
3839
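// Scatter shape conventions checked above (illustrative): values_in and
// values_out are N x K x C, indices is N x W and input is N x W x C. The
// final check requires K >= W, since the W input rows are scattered into the
// K rows of values_out.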
3840static LogicalResult ReduceInferReturnTypes(
3841 ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
3842 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3843 int64_t axisVal = axis.getValue().getSExtValue();
3844 if (!operandShape.hasRank() || operandShape.getRank() <= axisVal) {
3845 inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
3846 return success();
3847 }
3848
3849 SmallVector<int64_t> outputShape;
3850 operandShape.getDims(outputShape);
3851 outputShape[axisVal] = 1;
3852 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
3853 return success();
3854}
3855
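// Reduce inference example (illustrative): reducing a tensor<2x3x4> along
// axis = 1 yields tensor<2x1x4>; the reduced dimension is kept with size 1
// rather than being dropped.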
3856#define COMPATIBLE_RETURN_TYPES(OP) \
3857 bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) { \
3858 if (l.size() != r.size() || l.size() != 1) \
3859 return false; \
3860 if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0])) \
3861 return false; \
3862 return succeeded(verifyCompatibleShape(l[0], r[0])); \
3863 }
3864
3865#define REDUCE_SHAPE_INFER(OP) \
3866 LogicalResult OP::inferReturnTypeComponents( \
3867 MLIRContext *context, ::std::optional<Location> location, \
3868 OP::Adaptor adaptor, \
3869 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
3870 Type inputType = \
3871 llvm::cast<TensorType>(adaptor.getInput().getType()).getElementType(); \
3872 ShapeAdaptor inputShape(adaptor.getInput().getType()); \
3873 const Properties &prop = adaptor.getProperties(); \
3874 return ReduceInferReturnTypes(inputShape, inputType, prop.axis, \
3875 inferredReturnShapes); \
3876 } \
3877 COMPATIBLE_RETURN_TYPES(OP)
3878
3879REDUCE_SHAPE_INFER(tosa::ReduceAllOp)
3880REDUCE_SHAPE_INFER(tosa::ReduceAnyOp)
3881REDUCE_SHAPE_INFER(tosa::ReduceMaxOp)
3882REDUCE_SHAPE_INFER(tosa::ReduceMinOp)
3883REDUCE_SHAPE_INFER(tosa::ReduceProductOp)
3884REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
3885#undef REDUCE_SHAPE_INFER
3886COMPATIBLE_RETURN_TYPES(tosa::ConcatOp)
3887#undef COMPATIBLE_RETURN_TYPES
3888
3889template <typename T>
3890static LogicalResult verifyReduceOp(T op) {
3891 // All TOSA reduce Ops have input, output and axis.
3892 TensorType inputType = op.getInput().getType();
3893 TensorType outputType = op.getOutput().getType();
3894 int32_t reduceAxis = op.getAxis();
3895
3896 if (reduceAxis < 0) {
3897 op.emitOpError("reduce axis must not be negative");
3898 return failure();
3899 }
3900 if (inputType.hasRank()) {
3901 int64_t inputRank = inputType.getRank();
3902 // We allow for a special case where the input/output shape has rank 0 and
3903 // axis is also 0.
3904 if (reduceAxis >= inputRank && (reduceAxis != 0 || inputRank != 0)) {
3905 op.emitOpError("expect input tensor rank (")
3906 << inputRank << ") to be larger than reduce axis (" << reduceAxis
3907 << ")";
3908 return failure();
3909 }
3910 }
3911 if (outputType.hasRank()) {
3912 int64_t outputRank = outputType.getRank();
3913 if (inputType.hasRank() && outputRank != inputType.getRank()) {
3914 op.emitOpError(
3915 "expect output tensor rank to be equal to input tensor rank");
3916 return failure();
3917 }
3918 if (reduceAxis >= outputRank && (reduceAxis != 0 || outputRank != 0)) {
3919 op.emitOpError("expect output tensor rank (")
3920 << outputRank << ") to be larger than reduce axis (" << reduceAxis
3921 << ")";
3922 return failure();
3923 }
3924 // We can only verify the reduced dimension size to be 1 if this is not
3925 // the special case of output rank == 0.
3926 if (outputRank != 0) {
3927 auto outputShape = outputType.getShape();
3928 if (!outputType.isDynamicDim(reduceAxis) &&
3929 outputShape[reduceAxis] != 1) {
3930 op.emitOpError("expect reduced dimension size to be 1, got ")
3931 << outputShape[reduceAxis];
3932 return failure();
3933 }
3934 }
3935 }
3936 return success();
3937}
3938
3939LogicalResult tosa::ReduceAllOp::verify() { return verifyReduceOp(*this); }
3940LogicalResult tosa::ReduceAnyOp::verify() { return verifyReduceOp(*this); }
3941LogicalResult tosa::ReduceMaxOp::verify() { return verifyReduceOp(*this); }
3942LogicalResult tosa::ReduceMinOp::verify() { return verifyReduceOp(*this); }
3943LogicalResult tosa::ReduceProductOp::verify() { return verifyReduceOp(*this); }
3944LogicalResult tosa::ReduceSumOp::verify() { return verifyReduceOp(*this); }
3945
3946static LogicalResult NAryInferReturnTypes(
3947 const ValueShapeRange &operands,
3948 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3949 llvm::SmallVector<int64_t> outShape;
3950 if (resolveBroadcastShape(operands, outShape).failed()) {
3951 inferredReturnShapes.push_back(ShapedTypeComponents());
3952 } else {
3953 inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
3954 }
3955 return success();
3956}
3957
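// Broadcast inference example (illustrative): operands of tensor<2x1x3> and
// tensor<1x4x3> resolve to the shape 2x4x3; when the operand shapes cannot
// be broadcast together, the result is left unranked above.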
3958#define NARY_SHAPE_INFER(OP) \
3959 LogicalResult OP::inferReturnTypeComponents( \
3960 MLIRContext *context, ::std::optional<Location> location, \
3961 ValueShapeRange operands, DictionaryAttr attributes, \
3962 PropertyRef properties, RegionRange regions, \
3963 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
3964 return NAryInferReturnTypes(operands, inferredReturnShapes); \
3965 }
3966
3967NARY_SHAPE_INFER(tosa::AbsOp)
3968NARY_SHAPE_INFER(tosa::AddOp)
3969NARY_SHAPE_INFER(tosa::ArithmeticRightShiftOp)
3970NARY_SHAPE_INFER(tosa::BitwiseAndOp)
3971NARY_SHAPE_INFER(tosa::BitwiseOrOp)
3972NARY_SHAPE_INFER(tosa::BitwiseXorOp)
3973NARY_SHAPE_INFER(tosa::BitwiseNotOp)
3974NARY_SHAPE_INFER(tosa::CastOp)
3975NARY_SHAPE_INFER(tosa::CeilOp)
3976NARY_SHAPE_INFER(tosa::ClampOp)
3977NARY_SHAPE_INFER(tosa::ClzOp)
3978NARY_SHAPE_INFER(tosa::CosOp)
3979NARY_SHAPE_INFER(tosa::ExpOp)
3980NARY_SHAPE_INFER(tosa::FloorOp)
3981NARY_SHAPE_INFER(tosa::GreaterEqualOp)
3982NARY_SHAPE_INFER(tosa::GreaterOp)
3983NARY_SHAPE_INFER(tosa::IdentityOp)
3984NARY_SHAPE_INFER(tosa::IntDivOp)
3985NARY_SHAPE_INFER(tosa::LogOp)
3986NARY_SHAPE_INFER(tosa::LogicalAndOp)
3987NARY_SHAPE_INFER(tosa::LogicalLeftShiftOp)
3988NARY_SHAPE_INFER(tosa::LogicalNotOp)
3989NARY_SHAPE_INFER(tosa::LogicalOrOp)
3990NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
3991NARY_SHAPE_INFER(tosa::LogicalXorOp)
3992NARY_SHAPE_INFER(tosa::MaximumOp)
3993NARY_SHAPE_INFER(tosa::MinimumOp)
3994NARY_SHAPE_INFER(tosa::PowOp)
3995NARY_SHAPE_INFER(tosa::ReciprocalOp)
3996NARY_SHAPE_INFER(tosa::ReverseOp)
3997NARY_SHAPE_INFER(tosa::RsqrtOp)
3998NARY_SHAPE_INFER(tosa::SinOp)
3999NARY_SHAPE_INFER(tosa::SelectOp)
4000NARY_SHAPE_INFER(tosa::SubOp)
4001NARY_SHAPE_INFER(tosa::TanhOp)
4002NARY_SHAPE_INFER(tosa::ErfOp)
4003NARY_SHAPE_INFER(tosa::SigmoidOp)
4004#undef NARY_SHAPE_INFER
4005
4006LogicalResult tosa::NegateOp::inferReturnTypeComponents(
4007 MLIRContext *context, ::std::optional<Location> location,
4008 NegateOp::Adaptor adaptor,
4009 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4010 ShapeAdaptor inputShape(adaptor.getInput1().getType());
4011 inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
4012 return success();
4013}
4014
4015LogicalResult tosa::NegateOp::verify() {
4016 // Verify same element type
4017 const Type input1Type = getInput1().getType();
4018 const Type outputType = getOutput().getType();
4019 if (verifySameElementTypes(*this, input1Type, outputType).failed())
4020 return failure();
4021
4022 // Verify same shape
4023 const SmallVector<Type, 2> types = {input1Type, outputType};
4024 if (failed(verifyCompatibleShapes(types)))
4025 return emitOpError() << "requires the same shape for input1 and output";
4026
4027 const Type input1EType = getStorageElementTypeOrSelf(getInput1().getType());
4028 const Type input1ZpEType =
4029 getStorageElementTypeOrSelf(getInput1Zp().getType());
4030 if (input1EType != input1ZpEType) {
4031 return emitOpError("expect both input1 and its zero point are the same "
4032 "element type, got ")
4033 << input1EType << " and " << input1ZpEType;
4034 }
4035 const Type outputEType = getStorageElementTypeOrSelf(getOutput().getType());
4036 const Type outputZpEType =
4037 getStorageElementTypeOrSelf(getOutputZp().getType());
4038 if (outputEType != outputZpEType) {
4039 return emitOpError("expect both output and its zero point are the same "
4040 "element type, got ")
4041 << outputEType << " and " << outputZpEType;
4042 }
4043
4044 FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
4045 if (succeeded(maybeIZp) && verifyInput1ZeroPoint(*maybeIZp).failed())
4046 return failure();
4047
4048 FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
4049 if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
4050 return failure();
4051
4052 return success();
4053}
4054
4055static LogicalResult poolingInferReturnTypes(
4056 ShapeAdaptor inputShape, ArrayRef<int64_t> kernel, ArrayRef<int64_t> stride,
4057 ArrayRef<int64_t> pad,
4058 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4059 llvm::SmallVector<int64_t> outputShape;
4060 outputShape.resize(4, ShapedType::kDynamic);
4061
4062 // The output rank is only known if the input type is ranked.
4063 if (!inputShape.hasRank()) {
4064 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
4065 return success();
4066 }
4067
4068 // Batch and channel dimensions are unchanged by pooling.
4069 outputShape[0] = inputShape.getDimSize(0);
4070 outputShape[3] = inputShape.getDimSize(3);
4071
4072 int64_t height = inputShape.getDimSize(1);
4073 int64_t width = inputShape.getDimSize(2);
4074
4075 if (ShapedType::isStatic(height)) {
4076 int64_t padded = height + pad[0] + pad[1] - kernel[0];
4077 outputShape[1] = padded / stride[0] + 1;
4078 }
4079
4080 if (ShapedType::isStatic(width)) {
4081 int64_t padded = width + pad[2] + pad[3] - kernel[1];
4082 outputShape[2] = padded / stride[1] + 1;
4083 }
4084
4085 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
4086 return success();
4087}
4088
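// Pooling output-size example (illustrative): height = 32 with pad = {1, 1},
// kernel = 3 and stride = 2 gives (32 + 1 + 1 - 3) / 2 + 1 = 16 output rows;
// the width dimension is computed the same way.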
4089 template <typename AdaptorT>
4090 class ConvInferShapeAdaptor;
4091 
4092 class ConvInferShapeAdaptorBase {
4093protected:
4094 static void updateIfDynamic(int64_t &current, int64_t candidate) {
4095 if (ShapedType::isDynamic(current))
4096 current = candidate;
4097 }
4098};
4099
4100template <>
4101class ConvInferShapeAdaptor<Conv2DOp::Adaptor>
4102 : public ConvInferShapeAdaptorBase {
4103public:
4104 explicit ConvInferShapeAdaptor(Conv2DOp::Adaptor adaptor)
4105 : adaptor(adaptor) {}
4106
4107 void inferInputShape(SmallVectorImpl<int64_t> &outputShape,
4108 SmallVectorImpl<int64_t> &inputSpatial) {
4109 const ShapeAdaptor inputShape(adaptor.getInput().getType());
4110 if (!inputShape.hasRank())
4111 return;
4112
4113 const int64_t outputBatch = inputShape.getDimSize(0);
4114 const int64_t inputHeight = inputShape.getDimSize(1);
4115 const int64_t inputWidth = inputShape.getDimSize(2);
4116
4117 outputShape[0] = outputBatch;
4118 inputSpatial[0] = inputHeight;
4119 inputSpatial[1] = inputWidth;
4120 }
4121
4122 void inferWeightShape(SmallVectorImpl<int64_t> &outputShape,
4123 SmallVectorImpl<int64_t> &weightSpatial) {
4124 const ShapeAdaptor weightShape(adaptor.getWeight().getType());
4125 if (!weightShape.hasRank())
4126 return;
4127
4128 const int64_t outputChannels = weightShape.getDimSize(0);
4129 const int64_t kernelHeight = weightShape.getDimSize(1);
4130 const int64_t kernelWidth = weightShape.getDimSize(2);
4131
4132 outputShape[3] = outputChannels;
4133 weightSpatial[0] = kernelHeight;
4134 weightSpatial[1] = kernelWidth;
4135 }
4136
4137 int64_t getNumSpatialDims() const { return 2; }
4138 int64_t getOutputRank() const { return 4; }
4139
4140 LogicalResult getSpatialParameters(SmallVector<int64_t> &padValues,
4141 SmallVector<int64_t> &strideValues,
4142 SmallVector<int64_t> &dilationValues) {
4143 padValues.assign(adaptor.getPad().begin(), adaptor.getPad().end());
4144 strideValues.assign(adaptor.getStride().begin(), adaptor.getStride().end());
4145 dilationValues.assign(adaptor.getDilation().begin(),
4146 adaptor.getDilation().end());
4147 return success();
4148 }
4149
4150private:
4151 Conv2DOp::Adaptor adaptor;
4152};
4153
4154template <>
4155class ConvInferShapeAdaptor<Conv2DBlockScaledOp::Adaptor>
4156 : public ConvInferShapeAdaptorBase {
4157public:
4158 explicit ConvInferShapeAdaptor(Conv2DBlockScaledOp::Adaptor adaptor)
4159 : adaptor(adaptor) {}
4160
4161 void inferInputShape(SmallVectorImpl<int64_t> &outputShape,
4162 SmallVectorImpl<int64_t> &inputSpatial) {
4163 const ShapeAdaptor inputDataShape(adaptor.getInputData().getType());
4164 if (inputDataShape.hasRank()) {
4165 const int64_t outputBatch = inputDataShape.getDimSize(0);
4166 const int64_t inputHeight = inputDataShape.getDimSize(1);
4167 const int64_t inputWidth = inputDataShape.getDimSize(2);
4168
4169 outputShape[0] = outputBatch;
4170 inputSpatial[0] = inputHeight;
4171 inputSpatial[1] = inputWidth;
4172 }
4173
4174 const ShapeAdaptor inputScaleShape(adaptor.getInputScale().getType());
4175 if (!inputScaleShape.hasRank())
4176 return;
4177
4178 const int64_t scaleBatch = inputScaleShape.getDimSize(0);
4179 const int64_t scaleHeight = inputScaleShape.getDimSize(1);
4180 const int64_t scaleWidth = inputScaleShape.getDimSize(2);
4181
4182 updateIfDynamic(outputShape[0], scaleBatch);
4183 updateIfDynamic(inputSpatial[0], scaleHeight);
4184 updateIfDynamic(inputSpatial[1], scaleWidth);
4185 }
4186
4187 void inferWeightShape(SmallVectorImpl<int64_t> &outputShape,
4188 SmallVectorImpl<int64_t> &weightSpatial) {
4189 const ShapeAdaptor weightDataShape(adaptor.getWeightData().getType());
4190 if (weightDataShape.hasRank()) {
4191 const int64_t outputChannels = weightDataShape.getDimSize(0);
4192 const int64_t kernelHeight = weightDataShape.getDimSize(1);
4193 const int64_t kernelWidth = weightDataShape.getDimSize(2);
4194
4195 outputShape[3] = outputChannels;
4196 weightSpatial[0] = kernelHeight;
4197 weightSpatial[1] = kernelWidth;
4198 }
4199
4200 const ShapeAdaptor weightScaleShape(adaptor.getWeightScale().getType());
4201 if (!weightScaleShape.hasRank())
4202 return;
4203
4204 const int64_t scaleOutputChannels = weightScaleShape.getDimSize(0);
4205 const int64_t scaleKernelHeight = weightScaleShape.getDimSize(1);
4206 const int64_t scaleKernelWidth = weightScaleShape.getDimSize(2);
4207
4208 updateIfDynamic(outputShape[3], scaleOutputChannels);
4209 updateIfDynamic(weightSpatial[0], scaleKernelHeight);
4210 updateIfDynamic(weightSpatial[1], scaleKernelWidth);
4211 }
4212
4213 int64_t getNumSpatialDims() const { return 2; }
4214 int64_t getOutputRank() const { return 4; }
4215
4216 LogicalResult getSpatialParameters(SmallVector<int64_t> &padValues,
4217 SmallVector<int64_t> &strideValues,
4218 SmallVector<int64_t> &dilationValues) {
4219 if (!tosa::getConstShapeValues(adaptor.getPad().getDefiningOp(),
4220 padValues) ||
4221 !tosa::getConstShapeValues(adaptor.getStride().getDefiningOp(),
4222 strideValues) ||
4223 !tosa::getConstShapeValues(adaptor.getDilation().getDefiningOp(),
4224 dilationValues))
4225 return failure();
4226 return success();
4227 }
4228
4229private:
4230 Conv2DBlockScaledOp::Adaptor adaptor;
4231};
4232
4233template <>
4234class ConvInferShapeAdaptor<Conv3DOp::Adaptor>
4235 : public ConvInferShapeAdaptorBase {
4236public:
4237 explicit ConvInferShapeAdaptor(Conv3DOp::Adaptor adaptor)
4238 : adaptor(adaptor) {}
4239
4240 void inferInputShape(SmallVectorImpl<int64_t> &outputShape,
4241 SmallVectorImpl<int64_t> &inputSpatial) {
4242 const ShapeAdaptor inputShape(adaptor.getInput().getType());
4243 if (!inputShape.hasRank())
4244 return;
4245
4246 const int64_t outputBatch = inputShape.getDimSize(0);
4247 const int64_t inputDepth = inputShape.getDimSize(1);
4248 const int64_t inputHeight = inputShape.getDimSize(2);
4249 const int64_t inputWidth = inputShape.getDimSize(3);
4250
4251 outputShape[0] = outputBatch;
4252 inputSpatial[0] = inputDepth;
4253 inputSpatial[1] = inputHeight;
4254 inputSpatial[2] = inputWidth;
4255 }
4256
4257 void inferWeightShape(SmallVectorImpl<int64_t> &outputShape,
4258 SmallVectorImpl<int64_t> &weightSpatial) {
4259 const ShapeAdaptor weightShape(adaptor.getWeight().getType());
4260 if (!weightShape.hasRank())
4261 return;
4262
4263 const int64_t outputChannels = weightShape.getDimSize(0);
4264 const int64_t kernelDepth = weightShape.getDimSize(1);
4265 const int64_t kernelHeight = weightShape.getDimSize(2);
4266 const int64_t kernelWidth = weightShape.getDimSize(3);
4267
4268 outputShape[4] = outputChannels;
4269 weightSpatial[0] = kernelDepth;
4270 weightSpatial[1] = kernelHeight;
4271 weightSpatial[2] = kernelWidth;
4272 }
4273
4274 int64_t getNumSpatialDims() const { return 3; }
4275 int64_t getOutputRank() const { return 5; }
4276
4277 LogicalResult getSpatialParameters(SmallVector<int64_t> &padValues,
4278 SmallVector<int64_t> &strideValues,
4279 SmallVector<int64_t> &dilationValues) {
4280 padValues.assign(adaptor.getPad().begin(), adaptor.getPad().end());
4281 strideValues.assign(adaptor.getStride().begin(), adaptor.getStride().end());
4282 dilationValues.assign(adaptor.getDilation().begin(),
4283 adaptor.getDilation().end());
4284 return success();
4285 }
4286
4287private:
4288 Conv3DOp::Adaptor adaptor;
4289};
4290
4291 template <typename AdaptorT>
4292 static LogicalResult inferConvReturnTypeComponents(
4293 AdaptorT adaptor,
4294 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4295 ConvInferShapeAdaptor<AdaptorT> convShapeAdaptor(adaptor);
4296 llvm::SmallVector<int64_t> outputShape(convShapeAdaptor.getOutputRank(),
4297 ShapedType::kDynamic);
4298 llvm::SmallVector<int64_t> inputSpatial(convShapeAdaptor.getNumSpatialDims(),
4299 ShapedType::kDynamic);
4300 llvm::SmallVector<int64_t> weightSpatial(convShapeAdaptor.getNumSpatialDims(),
4301 ShapedType::kDynamic);
4302
4303 convShapeAdaptor.inferInputShape(outputShape, inputSpatial);
4304 convShapeAdaptor.inferWeightShape(outputShape, weightSpatial);
4305
4306 const ShapeAdaptor biasShape = adaptor.getBias().getType();
4307 if (biasShape.hasRank()) {
4308 const int64_t biasSize = biasShape.getDimSize(0);
4309 if (biasSize != 1) {
4310 const size_t outputChannelDim = convShapeAdaptor.getOutputRank() - 1;
4311 outputShape[outputChannelDim] =
4312 ShapedType::isDynamic(outputShape[outputChannelDim])
4313 ? biasSize
4314 : outputShape[outputChannelDim];
4315 }
4316 }
4317
4318 SmallVector<int64_t> padValues;
4319 SmallVector<int64_t> strideValues;
4320 SmallVector<int64_t> dilationValues;
4321 if (failed(convShapeAdaptor.getSpatialParameters(padValues, strideValues,
4322 dilationValues))) {
4323 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
4324 return success();
4325 }
4326
4327 for (int64_t dim = 0; dim < convShapeAdaptor.getNumSpatialDims(); ++dim) {
4328 if (!ShapedType::isStatic(inputSpatial[dim]) ||
4329 !ShapedType::isStatic(weightSpatial[dim]))
4330 continue;
4331 const int64_t inputSize =
4332 inputSpatial[dim] + padValues[2 * dim] + padValues[2 * dim + 1];
4333 const int64_t filterSize =
4334 (weightSpatial[dim] - 1) * dilationValues[dim] + 1;
4335 const int64_t unstridedResult = inputSize - filterSize + 1;
4336 outputShape[dim + 1] = (unstridedResult - 1) / strideValues[dim] + 1;
4337 }
4338
4339 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
4340 return success();
4341}
4342
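// Convolution output-size example for the loop above (illustrative): an
// input spatial size of 14 with pad = {1, 1}, kernel = 3, dilation = 2 and
// stride = 2 gives filterSize = (3 - 1) * 2 + 1 = 5, inputSize = 16,
// unstridedResult = 16 - 5 + 1 = 12 and an output size of
// (12 - 1) / 2 + 1 = 6.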
4343LogicalResult Conv2DOp::inferReturnTypeComponents(
4344 MLIRContext *context, ::std::optional<Location> location,
4345 Conv2DOp::Adaptor adaptor,
4346 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4347 return inferConvReturnTypeComponents(adaptor, inferredReturnShapes);
4348}
4349
4350LogicalResult Conv2DOp::verify() {
4351 if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
4352 verifyConvOpErrorIf(*this).failed())
4353 return failure();
4354 return success();
4355}
4356
4357LogicalResult Conv2DBlockScaledOp::inferReturnTypeComponents(
4358 MLIRContext *context, ::std::optional<Location> location,
4359 Conv2DBlockScaledOp::Adaptor adaptor,
4360 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4361 return inferConvReturnTypeComponents(adaptor, inferredReturnShapes);
4362}
4363
4364LogicalResult Conv2DBlockScaledOp::verify() {
4365 if (failed(verifySameElementTypes(*this, getInputData().getType(),
4366 getWeightData().getType(), "input_data",
4367 "weight_data")) ||
4368 failed(verifySameElementTypes(*this, getInputScale().getType(),
4369 getWeightScale().getType(), "input_scale",
4370 "weight_scale")) ||
4371 failed(verifySameElementTypes(*this, getBias().getType(),
4372 getOutput().getType(), "bias", "output")))
4373 return failure();
4374
4375 // Verify input shape compatibility
4376 int64_t N = ShapedType::kDynamic;
4377 int64_t IH = ShapedType::kDynamic;
4378 int64_t IW = ShapedType::kDynamic;
4379 int64_t IC = ShapedType::kDynamic;
4380 int64_t multiplesOfIC = ShapedType::kDynamic;
4381 int64_t OC = ShapedType::kDynamic;
4382 int64_t KH = ShapedType::kDynamic;
4383 int64_t KW = ShapedType::kDynamic;
4384
4385 const ShapeAdaptor inputDataShape(getInputData().getType());
4386 if (inputDataShape.hasRank()) {
4387 N = inputDataShape.getDimSize(0);
4388 IH = inputDataShape.getDimSize(1);
4389 IW = inputDataShape.getDimSize(2);
4390 IC = inputDataShape.getDimSize(3);
4391 }
4392
4393 const ShapeAdaptor inputScaleShape(getInputScale().getType());
4394 if (inputScaleShape.hasRank()) {
4395 if (failed(tryUpdateDimOrFailure(*this, N, inputScaleShape.getDimSize(0),
4396 "input_scale", "batch size")) ||
4397 failed(tryUpdateDimOrFailure(*this, IH, inputScaleShape.getDimSize(1),
4398 "input_scale", "input height")) ||
4399 failed(tryUpdateDimOrFailure(*this, IW, inputScaleShape.getDimSize(2),
4400 "input_scale", "input width")))
4401 return failure();
4402 multiplesOfIC = inputScaleShape.getDimSize(3);
4403 }
4404
4405 const ShapeAdaptor weightDataShape(getWeightData().getType());
4406 if (weightDataShape.hasRank()) {
4407 OC = weightDataShape.getDimSize(0);
4408 KH = weightDataShape.getDimSize(1);
4409 KW = weightDataShape.getDimSize(2);
4410 if (failed(tryUpdateDimOrFailure(*this, IC, weightDataShape.getDimSize(3),
4411 "weight_data", "input channels")))
4412 return failure();
4413 }
4414
4415 const ShapeAdaptor weightScaleShape(getWeightScale().getType());
4416 if (weightScaleShape.hasRank()) {
4417 if (failed(tryUpdateDimOrFailure(*this, OC, weightScaleShape.getDimSize(0),
4418 "weight_scale", "output channels")) ||
4419 failed(tryUpdateDimOrFailure(*this, KH, weightScaleShape.getDimSize(1),
4420 "weight_scale", "kernel height")) ||
4421 failed(tryUpdateDimOrFailure(*this, KW, weightScaleShape.getDimSize(2),
4422 "weight_scale", "kernel width")) ||
4423 failed(tryUpdateDimOrFailure(*this, multiplesOfIC,
4424 weightScaleShape.getDimSize(3),
4425 "weight_scale", "input channel blocks")))
4426 return failure();
4427 }
4428
4429 const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize());
4430 if (blockSize != BlockSizeAttr::getBlockSizeValue(BlockSize::BLOCK_SIZE_32))
4431 return emitOpError("expect block size to be 32, got ") << blockSize;
4432 // Verify IC is a multiple of block size
4433 if (ShapedType::isStatic(IC) && IC % blockSize != 0)
4434 return emitOpError("expect IC to be a multiple of block size, got IC=")
4435 << IC << ", block_size=" << blockSize;
4436
4437 // Verify multiplesOfIC is IC / block size
4438 if (ShapedType::isStatic(IC) && ShapedType::isStatic(multiplesOfIC) &&
4439 multiplesOfIC != IC / blockSize)
4440 return emitOpError(
4441 "expect scale operands dimension 2 to equal IC/block_size (")
4442 << IC << "/" << blockSize << ")"
4443 << ", got " << multiplesOfIC;
4444
4445 // Verify pad/stride/dilation values
4446 SmallVector<int64_t> padValues;
4447 if (tosa::getConstShapeValues(getPad().getDefiningOp(), padValues)) {
4448 if (llvm::any_of(padValues, [](int64_t p) { return p < 0; }))
4449 return emitOpError("expect all padding values to be >= 0, got ")
4450 << padValues;
4451 }
4452
4453 SmallVector<int64_t> strideValues;
4454 if (tosa::getConstShapeValues(getStride().getDefiningOp(), strideValues)) {
4455 if (llvm::any_of(strideValues, [](int64_t s) { return s < 1; }))
4456 return emitOpError("expect all stride values to be >= 1, got ")
4457 << strideValues;
4458 }
4459
4460 SmallVector<int64_t> dilationValues;
4461 if (tosa::getConstShapeValues(getDilation().getDefiningOp(),
4462 dilationValues)) {
4463 if (llvm::any_of(dilationValues, [](int64_t d) { return d < 1; }))
4464 return emitOpError("expect all dilation values to be >= 1, got ")
4465 << dilationValues;
4466 }
4467
4468 // Verify output shape compatibility
4469 const ShapeAdaptor outputShape(getOutput().getType());
4470 if (!padValues.empty() && !strideValues.empty() && !dilationValues.empty() &&
4471 outputShape.hasRank()) {
4472 if (failed(verifyConvOutputSize(*this, IH, KH, outputShape.getDimSize(1),
4473 padValues[0], padValues[1], strideValues[0],
4474 dilationValues[0], "height", "y", "top",
4475 "bottom")) ||
4476 failed(verifyConvOutputSize(*this, IW, KW, outputShape.getDimSize(2),
4477 padValues[2], padValues[3], strideValues[1],
4478 dilationValues[1], "width", "x", "left",
4479 "right")))
4480 return failure();
4481 }
4482
4483 // Verify bias
4484 const ShapeAdaptor biasShape(getBias().getType());
4485 if (biasShape.hasRank() && outputShape.hasRank()) {
4486 const int64_t biasChannels = biasShape.getDimSize(0);
4487 const int64_t outputChannels =
4488 outputShape.getDimSize(outputShape.getRank() - 1);
4489 if (biasChannels == ShapedType::kDynamic ||
4490 outputChannels == ShapedType::kDynamic)
4491 // Skip the remaining checks if biasChannels or outputChannels is dynamic
4492 return success();
4493
4494 if (biasChannels != outputChannels && biasChannels != 1)
4495 return emitOpError(
4496 "bias channels expected to be equal to output channels (")
4497 << outputChannels << ") or 1, got " << biasChannels;
4498 }
4499
4500 return success();
4501}
4502
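// Block-scaled convolution example (illustrative): with IC = 64 and the
// required block_size of 32, input_scale has shape N x IH x IW x 2 and
// weight_scale has shape OC x KH x KW x 2, one scale per block of 32 input
// channels.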
4503LogicalResult Conv3DOp::inferReturnTypeComponents(
4504 MLIRContext *context, ::std::optional<Location> location,
4505 Conv3DOp::Adaptor adaptor,
4506 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4507 return inferConvReturnTypeComponents(adaptor, inferredReturnShapes);
4508}
4509
4510LogicalResult Conv3DOp::verify() {
4511 if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
4512 verifyConvOpErrorIf(*this).failed())
4513 return failure();
4514 return success();
4515}
4516
4517LogicalResult AvgPool2dOp::inferReturnTypeComponents(
4518 MLIRContext *context, ::std::optional<Location> location,
4519 AvgPool2dOp::Adaptor adaptor,
4520 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4521 ShapeAdaptor inputShape(adaptor.getInput().getType());
4522 const Properties &prop = adaptor.getProperties();
4523 return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
4524 inferredReturnShapes);
4525}
4526
4527LogicalResult AvgPool2dAdaptiveOp::inferReturnTypeComponents(
4528 MLIRContext *context, ::std::optional<Location> location,
4529 AvgPool2dAdaptiveOp::Adaptor adaptor,
4530 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4531 ShapeAdaptor inputShape(adaptor.getInput().getType());
4532
4533 llvm::SmallVector<int64_t> kernelValues;
4534 llvm::SmallVector<int64_t> strideValues;
4535 llvm::SmallVector<int64_t> padValues;
4536 if (tosa::getConstShapeValues(adaptor.getKernel().getDefiningOp(),
4537 kernelValues) &&
4538 tosa::getConstShapeValues(adaptor.getStride().getDefiningOp(),
4539 strideValues) &&
4540 tosa::getConstShapeValues(adaptor.getPad().getDefiningOp(), padValues)) {
4541 return poolingInferReturnTypes(inputShape, kernelValues, strideValues,
4542 padValues, inferredReturnShapes);
4543 }
4544
4545 llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
4546 if (inputShape.hasRank()) {
4547 // Keep N & C as pooling only changes H & W.
4548 outputShape[0] = inputShape.getDimSize(0);
4549 outputShape[3] = inputShape.getDimSize(3);
4550 }
4551
4552 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
4553 return success();
4554}
4555
4556LogicalResult MaxPool2dOp::inferReturnTypeComponents(
4557 MLIRContext *context, ::std::optional<Location> location,
4558 MaxPool2dOp::Adaptor adaptor,
4559 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4560 ShapeAdaptor inputShape(adaptor.getInput().getType());
4561 const Properties &prop = adaptor.getProperties();
4562 return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
4563 inferredReturnShapes);
4564}
4565
4566LogicalResult MaxPool2dAdaptiveOp::inferReturnTypeComponents(
4567 MLIRContext *context, ::std::optional<Location> location,
4568 MaxPool2dAdaptiveOp::Adaptor adaptor,
4569 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4570 ShapeAdaptor inputShape(adaptor.getInput().getType());
4571
4572 llvm::SmallVector<int64_t> kernelValues;
4573 llvm::SmallVector<int64_t> strideValues;
4574 llvm::SmallVector<int64_t> padValues;
4575 if (tosa::getConstShapeValues(adaptor.getKernel().getDefiningOp(),
4576 kernelValues) &&
4577 tosa::getConstShapeValues(adaptor.getStride().getDefiningOp(),
4578 strideValues) &&
4579 tosa::getConstShapeValues(adaptor.getPad().getDefiningOp(), padValues)) {
4580 return poolingInferReturnTypes(inputShape, kernelValues, strideValues,
4581 padValues, inferredReturnShapes);
4582 }
4583
4584 llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
4585 if (inputShape.hasRank()) {
4586 outputShape[0] = inputShape.getDimSize(0);
4587 outputShape[3] = inputShape.getDimSize(3);
4588 }
4589 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
4590 return success();
4591}
4592
4593LogicalResult MaxPool2dOp::verify() {
4594 if (failed(verifySameElementTypes(*this, /* intype = */ getInput().getType(),
4595 /* outType = */ getOutput().getType())))
4596 return failure();
4597
4598 if (failed(verifyPoolingOp(*this)))
4599 return failure();
4600
4601 return success();
4602}
4603
4604LogicalResult MaxPool2dAdaptiveOp::verify() {
4605 if (failed(verifySameElementTypes(*this, /* intype = */ getInput().getType(),
4606 /* outType = */ getOutput().getType())))
4607 return failure();
4608
4609 AdaptivePoolingConstShapeValues values;
4610 extractAdaptivePoolingConstShapeOperands(*this, values);
4611
4612 if (failed(verifyPoolingOpImpl(getOperation(), values.kernel, values.stride,
4613 values.pad, getInput(), getOutput())))
4614 return failure();
4615
4616 return success();
4617}
4618
4619LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
4620 MLIRContext *context, ::std::optional<Location> location,
4621 DepthwiseConv2DOp::Adaptor adaptor,
4622 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4623 llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
4624
4625 int64_t inputWidth = ShapedType::kDynamic;
4626 int64_t inputHeight = ShapedType::kDynamic;
4627 int64_t inputChannels = ShapedType::kDynamic;
4628
4629 int64_t weightWidth = ShapedType::kDynamic;
4630 int64_t weightHeight = ShapedType::kDynamic;
4631 int64_t depthChannels = ShapedType::kDynamic;
4632
4633 // Input shape describes input width/height and batch.
4634 ShapeAdaptor inputShape(adaptor.getInput().getType());
4635 if (inputShape.hasRank()) {
4636 outputShape[0] = inputShape.getDimSize(0);
4637 inputHeight = inputShape.getDimSize(1);
4638 inputWidth = inputShape.getDimSize(2);
4639 inputChannels = inputShape.getDimSize(3);
4640 }
4641
4642 // Weight shape describes the filter width/height, the input channels, and the depth multiplier.
4643 ShapeAdaptor weightShape(adaptor.getWeight().getType());
4644 if (weightShape.hasRank()) {
4645 weightHeight = weightShape.getDimSize(0);
4646 weightWidth = weightShape.getDimSize(1);
4647 inputChannels = ShapedType::isDynamic(inputChannels)
4648 ? weightShape.getDimSize(2)
4649 : inputChannels;
4650 depthChannels = weightShape.getDimSize(3);
4651 }
4652
4653 // If both inputChannels and depthChannels are available we can determine
4654 // the output channels.
4655 if (ShapedType::isStatic(inputChannels) &&
4656 ShapedType::isStatic(depthChannels)) {
4657 outputShape[3] = inputChannels * depthChannels;
4658 }
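// Illustrative note: the output channel count of a depthwise convolution is
// C * M. For example, an input of tensor<1x16x16x8xf32> (C = 8) and a weight
// of tensor<3x3x8x2xf32> (depth multiplier M = 2) give outputShape[3] = 16.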
4659
4660 // Bias shape can describe the output channels.
4661 ShapeAdaptor biasShape(adaptor.getBias().getType());
4662 if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[3])) {
4663 int64_t bc = biasShape.getDimSize(0);
4664 if (bc != ShapedType::kDynamic && bc != 1)
4665 outputShape[3] = bc;
4666 }
4667
4668 llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
4669 llvm::ArrayRef<int64_t> padding = adaptor.getPad();
4670 llvm::ArrayRef<int64_t> stride = adaptor.getStride();
4671
4672 if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
4673 int64_t inputSize = inputHeight + padding[0] + padding[1];
4674 int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
4675 int64_t unstridedResult = inputSize - filterSize + 1;
4676 outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
4677 }
4678
4679 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
4680 int64_t inputSize = inputWidth + padding[2] + padding[3];
4681 int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
4682 int64_t unstridedResult = inputSize - filterSize + 1;
4683 outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
4684 }
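// Worked example of the height/width computations above: for IH = 16,
// pad = {1, 1}, KH = 3, dilation = 1 and stride = 2, inputSize = 18,
// filterSize = (3 - 1) * 1 + 1 = 3, unstridedResult = 18 - 3 + 1 = 16, so
// outputShape[1] = (16 - 1) / 2 + 1 = 8 (integer division).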
4685
4686 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
4687 return success();
4688}
4689
4690LogicalResult DepthwiseConv2DOp::verify() {
4691 if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
4692 verifyConvOpErrorIf(*this).failed())
4693 return failure();
4694 return success();
4695}
4696
4697LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
4698 MLIRContext *context, ::std::optional<Location> location,
4699 TransposeConv2DOp::Adaptor adaptor,
4700 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4701 llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
4702
4703 int64_t inputWidth = ShapedType::kDynamic;
4704 int64_t inputHeight = ShapedType::kDynamic;
4705 int64_t weightWidth = ShapedType::kDynamic;
4706 int64_t weightHeight = ShapedType::kDynamic;
4707
4708 // Input shape describes input width/height and batch.
4709 ShapeAdaptor inputShape(adaptor.getInput().getType());
4710 if (inputShape.hasRank()) {
4711 outputShape[0] = ShapedType::isDynamic(outputShape[0])
4712 ? inputShape.getDimSize(0)
4713 : outputShape[0];
4714 inputHeight = inputShape.getDimSize(1);
4715 inputWidth = inputShape.getDimSize(2);
4716 }
4717
4718 // Weight shape describes the filter width/height and the output channels.
4719 ShapeAdaptor weightShape(adaptor.getWeight().getType());
4720 if (weightShape.hasRank()) {
4721 outputShape[3] = ShapedType::isDynamic(outputShape[3])
4722 ? weightShape.getDimSize(0)
4723 : outputShape[3];
4724 weightHeight = weightShape.getDimSize(1);
4725 weightWidth = weightShape.getDimSize(2);
4726 }
4727
4728 // Bias shape can describe the output channels.
4729 ShapeAdaptor biasShape(adaptor.getBias().getType());
4730 if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[3])) {
4731 int64_t bc = biasShape.getDimSize(0);
4732 if (bc != ShapedType::kDynamic && bc != 1)
4733 outputShape[3] = bc;
4734 }
4735
4736 llvm::ArrayRef<int64_t> padding = adaptor.getOutPad();
4737 llvm::ArrayRef<int64_t> stride = adaptor.getStride();
4738
4739 if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
4740 int64_t calculateSize =
4741 (inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight;
4742 outputShape[1] =
4743 ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1];
4744 }
4745
4746 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
4747 int64_t calculateSize =
4748 (inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth;
4749 outputShape[2] =
4750 ShapedType::isDynamic(outputShape[2]) ? calculateSize : outputShape[2];
4751 }
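// Worked example: for IH = 8, stride_y = 2, out_pad = {0, 0} and KH = 3 the
// inferred height is (8 - 1) * 2 + 0 + 0 + 3 = 17, i.e. the inverse of the
// forward convolution's output-size formula.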
4752
4753 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
4754 return success();
4755}
4756
4757LogicalResult TransposeConv2DOp::verify() {
4758 if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
4759 return failure();
4760
4761 const llvm::ArrayRef<int64_t> strides = getStride();
4762 const int64_t strideY = strides[0];
4763 const int64_t strideX = strides[1];
4764
4765 if (strideY < 1 || strideX < 1)
4766 return emitOpError("expect all stride values to be >= 1, got [")
4767 << strides << "]";
4768
4769 const auto checkPadAgainstKernelDim =
4770 [this](int64_t padValue, int64_t kernelDimSize, llvm::StringRef padName,
4771 llvm::StringRef kernelDimName) -> LogicalResult {
4772 if (padValue <= -kernelDimSize)
4773 return emitOpError("expected ")
4774 << padName << " > -" << kernelDimName << ", but got: " << padName
4775 << "=" << padValue << " and " << kernelDimName << "="
4776 << kernelDimSize;
4777 return success();
4778 };
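// Illustrative note: negative output padding is accepted as long as it does
// not consume the full kernel extent. With KH = 3, out_pad_top = -2 passes
// the check above (-2 > -3), whereas out_pad_top = -3 is rejected.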
4779
4780 const llvm::ArrayRef<int64_t> padding = getOutPad();
4781 const int64_t outPadTop = padding[0];
4782 const int64_t outPadBottom = padding[1];
4783 const int64_t outPadLeft = padding[2];
4784 const int64_t outPadRight = padding[3];
4785
4786 const auto weightType =
4787 llvm::dyn_cast<RankedTensorType>(getWeight().getType());
4788
4789 if (weightType) {
4790 const int64_t kernelHeight = weightType.getDimSize(1);
4791 if (ShapedType::isStatic(kernelHeight)) {
4792 if (failed(checkPadAgainstKernelDim(outPadTop, kernelHeight,
4793 "out_pad_top", "KH")))
4794 return failure();
4795
4796 if (failed(checkPadAgainstKernelDim(outPadBottom, kernelHeight,
4797 "out_pad_bottom", "KH")))
4798 return failure();
4799 }
4800
4801 const int64_t kernelWidth = weightType.getDimSize(2);
4802 if (ShapedType::isStatic(kernelWidth)) {
4803 if (failed(checkPadAgainstKernelDim(outPadLeft, kernelWidth,
4804 "out_pad_left", "KW")))
4805 return failure();
4806
4807 if (failed(checkPadAgainstKernelDim(outPadRight, kernelWidth,
4808 "out_pad_right", "KW")))
4809 return failure();
4810 }
4811 }
4812
4813 // Rest of the checks depend on the output type being a RankedTensorType
4814 const auto outputType =
4815 llvm::dyn_cast<RankedTensorType>(getOutput().getType());
4816 if (!outputType)
4817 return success();
4818
4819 const auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
4820 if (inputType && weightType) {
4821 const int64_t inputHeight = inputType.getDimSize(1);
4822 const int64_t kernelHeight = weightType.getDimSize(1);
4823 const int64_t outputHeight = outputType.getDimSize(1);
4824
4825 if (ShapedType::isStatic(inputHeight) &&
4826 ShapedType::isStatic(outputHeight)) {
4827 if (outputHeight !=
4828 (inputHeight - 1) * strideY + outPadTop + outPadBottom + kernelHeight)
4829 return emitOpError(
4830 "dimension mismatch: expected OH == (IH - 1) * stride_y "
4831 "+ out_pad_top + out_pad_bottom + KH, but got ")
4832 << outputHeight << " != (" << inputHeight << " - 1) * "
4833 << strideY << " + " << outPadTop << " + " << outPadBottom
4834 << " + " << kernelHeight;
4835 }
4836
4837 const int64_t inputWidth = inputType.getDimSize(2);
4838 const int64_t kernelWidth = weightType.getDimSize(2);
4839 const int64_t outputWidth = outputType.getDimSize(2);
4840
4841 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(outputWidth)) {
4842 if (outputWidth !=
4843 (inputWidth - 1) * strideX + outPadLeft + outPadRight + kernelWidth)
4844 return emitOpError(
4845 "dimension mismatch: expected OW == (IW - 1) * stride_x "
4846 "+ out_pad_left + out_pad_right + KW, but got ")
4847 << outputWidth << " != (" << inputWidth << " - 1) * " << strideX
4848 << " + " << outPadLeft << " + " << outPadRight << " + "
4849 << kernelWidth;
4850 }
4851 }
4852
4853 const auto biasType = llvm::dyn_cast<RankedTensorType>(getBias().getType());
4854
4855 if (!biasType)
4856 return success();
4857
4858 const int64_t biasChannels = biasType.getDimSize(0);
4859
4860 // Skip further checks if bias is dynamic
4861 if (biasChannels == ShapedType::kDynamic)
4862 return success();
4863
4864 const int64_t outputChannels = outputType.getDimSize(3);
4865 if (!ShapedType::isDynamic(outputChannels) &&
4866 biasChannels != outputChannels && biasChannels != 1)
4867 return emitOpError(
4868 "bias channels expected to be equal to output channels (")
4869 << outputChannels << ") or 1, got " << biasChannels;
4870
4871 return success();
4872}
4873
4874LogicalResult RescaleOp::verify() {
4875 const auto inputType = llvm::cast<ShapedType>(getInput().getType());
4876 auto inputElementType =
4877 getStorageElementTypeOrSelf(inputType.getElementType());
4878 if (!mlir::isa<IntegerType>(inputElementType)) {
4879 emitOpError("expect input to have integer element type, got ")
4880 << inputElementType;
4881 return failure();
4882 }
4883
4884 const auto outputType = llvm::cast<ShapedType>(getOutput().getType());
4885 auto outputElementType =
4886 getStorageElementTypeOrSelf(outputType.getElementType());
4887 if (!mlir::isa<IntegerType>(outputElementType)) {
4888 emitOpError("expect output to have integer element type, got ")
4889 << outputElementType;
4890 return failure();
4891 }
4892
4893 if (verifyRescaleValueAndZpTypes(*this, getInput(), getInputZp(), "input")
4894 .failed())
4895 return failure();
4896
4897 if (verifyRescaleValueAndZpTypes(*this, getOutput(), getOutputZp(), "output")
4898 .failed())
4899 return failure();
4900
4901 FailureOr<int64_t> maybeIZp = getInputZeroPoint();
4902 if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
4903 return failure();
4904
4905 FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
4906 if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
4907 return failure();
4908
4909 const auto multiplierType = llvm::cast<ShapedType>(getMultiplier().getType());
4910 // multiplier element type must be i32 for scale32 = true
4911 if (getScale32() && !multiplierType.getElementType().isInteger(32)) {
4912 emitOpError("expect i32 element type for multiplier for scale32=true, got ")
4913 << multiplierType.getElementType();
4914 return failure();
4915 }
4916
4917 // multiplier element type must be i16 for scale32 = false
4918 if (!getScale32() && !multiplierType.getElementType().isInteger(16)) {
4919 emitOpError(
4920 "expect i16 element type for multiplier for scale32=false, got ")
4921 << multiplierType.getElementType();
4922 return failure();
4923 }
4924
4925 if (!inputType.hasRank())
4926 return success();
4927
4928 // multiplier/shift must have shape = {numChannels},
4929 // where numChannels is 1 if per_channel = false,
4930 // otherwise numChannels is the size of the input shape's last axis
4931 int64_t numChannels = 1;
4932 if (getPerChannel()) {
4933 if (inputType.getRank() < 1) {
4934 emitOpError("requires input to be at least rank 1 when per_channel is "
4935 "true, but got rank ")
4936 << inputType.getRank();
4937 return failure();
4938 }
4939 numChannels = inputType.getDimSize(inputType.getRank() - 1);
4940 }
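// Illustrative note: with per_channel = true and an input of
// tensor<2x3x8xi8>, numChannels = 8, so multiplier must have shape {8}
// (with i32 elements when scale32 = true) and shift must have shape {8};
// with per_channel = false both must have shape {1}.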
4941
4942 if (outputType.hasRank()) {
4943 if (failed(verifyOutputShapeCompatibleWithExpected(
4944 getOperation(), outputType, inputType.getShape())))
4945 return failure();
4946 }
4947
4948 if (multiplierType.hasRank()) {
4949 ArrayRef<int64_t> multiplierShape = multiplierType.getShape();
4950 // multiplier input has rank 1 by dialect definition
4951 if (multiplierShape[0] != ShapedType::kDynamic &&
4952 multiplierShape[0] != numChannels) {
4953 emitOpError("expect shape of { ")
4954 << numChannels << " } for multiplier input, got { "
4955 << multiplierShape[0] << " }";
4956 return failure();
4957 }
4958 }
4959
4960 const auto shiftType = llvm::cast<ShapedType>(getShift().getType());
4961 if (shiftType.hasRank()) {
4962 ArrayRef<int64_t> shiftShape = shiftType.getShape();
4963 // shift input has rank 1 by dialect definition
4964 if (shiftShape[0] != ShapedType::kDynamic && shiftShape[0] != numChannels) {
4965 emitOpError("expect shape of { ")
4966 << numChannels << " } for shift input, got { " << shiftShape[0]
4967 << " }";
4968 return failure();
4969 }
4970 }
4971
4972 return success();
4973}
4974
4975LogicalResult RescaleOp::inferReturnTypeComponents(
4976 MLIRContext *context, ::std::optional<Location> location,
4977 RescaleOp::Adaptor adaptor,
4978 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4979 ShapeAdaptor inputShape(adaptor.getInput().getType());
4980 inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
4981 return success();
4982}
4983
4984LogicalResult CastFromBlockScaledOp::inferReturnTypeComponents(
4985 MLIRContext *context, ::std::optional<Location> location,
4986 CastFromBlockScaledOp::Adaptor adaptor,
4987 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4988 const ShapeAdaptor inputShape(adaptor.getInputData().getType());
4989 inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
4990 return success();
4991}
4992
4993LogicalResult CastFromBlockScaledOp::verify() {
4994 const Type inputDataType = getInputData().getType();
4995 const Type outputDataType = getResult().getType();
4996 if (failed(verifyCompatibleShape(inputDataType, outputDataType)))
4997 return emitOpError() << "require compatible shapes for input_data ("
4998 << inputDataType << ") and " << "output_data ("
4999 << outputDataType << ")";
5000
5001 const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);
5002
5003 if (inputDataShape.hasRank()) {
5004 const unsigned int blockSize =
5005 BlockSizeAttr::getBlockSizeValue(getBlockSize());
5006 if (blockSize != BlockSizeAttr::getBlockSizeValue(BlockSize::BLOCK_SIZE_32))
5007 return emitOpError("expect block size to be 32, got ") << blockSize;
5008 const int64_t inputDataLastDim =
5009 inputDataShape.getDimSize(inputDataShape.getRank() - 1);
5010 if (inputDataLastDim % blockSize != 0)
5011 return emitOpError() << "expect last dimension of input_data ("
5012 << inputDataLastDim
5013 << ") to be divisible by block_size (" << blockSize
5014 << ")";
5015
5016 const Type inputScaleType = getInputScale().getType();
5017 const ShapeAdaptor inputScaleShape = ShapeAdaptor(inputScaleType);
5018
5019 if (inputScaleShape.hasRank()) {
5020 SmallVector<int64_t> inputDataDims, inputScaleDims;
5021 inputDataShape.getDims(inputDataDims);
5022 inputScaleShape.getDims(inputScaleDims);
5023
5024 if (inputDataDims.size() != inputScaleDims.size() ||
5025 failed(verifyCompatibleShape(
5026 ArrayRef<int64_t>(inputDataDims).drop_back(1),
5027 ArrayRef<int64_t>(inputScaleDims).drop_back(1))))
5028 return emitOpError()
5029 << "require compatible shapes for input_data (" << inputDataType
5030 << ") and " << "input_scale (" << inputScaleType
5031 << ") except for the last dimension";
5032
5033 const SmallVector<int64_t, 2> dimsToCheck{inputDataLastDim / blockSize,
5034 inputScaleDims.back()};
5035 if (ShapedType::isStatic(inputDataLastDim) &&
5036 failed(verifyCompatibleDims(dimsToCheck)))
5037 return emitOpError()
5038 << "expect last dimension of input_scale ("
5039 << inputScaleDims.back()
5040 << ") to be equal to last dimension of input_data / block_size ("
5041 << inputDataDims.back() / blockSize << ")";
5042 }
5043 }
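// Illustrative note (shapes are hypothetical): with block_size = 32, an
// input_data of shape 4x64 passes the divisibility check and requires an
// input_scale of shape 4x2 (64 / 32 = 2), with all leading dimensions
// matching input_data.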
5044
5045 return success();
5046}
5047
5048LogicalResult CastToBlockScaledOp::inferReturnTypeComponents(
5049 MLIRContext *context, ::std::optional<Location> location,
5050 CastToBlockScaledOp::Adaptor adaptor,
5051 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
5052 const ShapeAdaptor inputShape(adaptor.getInputData().getType());
5053 inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
5054 if (!inputShape.hasRank())
5055 return success();
5056
5057 // Calculate output_scale shape if ranked input provided
5058 SmallVector<int64_t> outputScaleShape;
5059 inputShape.getDims(outputScaleShape);
5060 const int64_t lastDimLoc = inputShape.getRank() - 1;
5061 const int64_t lastDimSize = inputShape.getDimSize(lastDimLoc);
5062 if (ShapedType::isStatic(lastDimSize)) {
5063 const unsigned int blockSize =
5064 BlockSizeAttr::getBlockSizeValue(adaptor.getBlockSize());
5065 outputScaleShape[lastDimLoc] = lastDimSize / blockSize;
5066 }
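// Illustrative note: an input_data of shape 2x64 with block_size = 32 infers
// an output_scale shape of 2x2 (64 / 32); a dynamic last dimension is left
// dynamic in the scale shape as well.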
5067 inferredReturnShapes.push_back(ShapedTypeComponents(outputScaleShape));
5068 return success();
5069}
5070
5071LogicalResult CastToBlockScaledOp::verify() {
5072 const Type inputDataType = getInputData().getType();
5073 const Type outputDataType = getResult(0).getType();
5074 if (failed(verifyCompatibleShape(inputDataType, outputDataType)))
5075 return emitOpError() << "require compatible shapes for input_data ("
5076 << inputDataType << ") and " << "output_data ("
5077 << outputDataType << ")";
5078
5079 const unsigned int blockSize =
5080 BlockSizeAttr::getBlockSizeValue(getBlockSize());
5081 if (blockSize != BlockSizeAttr::getBlockSizeValue(BlockSize::BLOCK_SIZE_32))
5082 return emitOpError("expect block size to be 32, got ") << blockSize;
5083 const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);
5084 if (inputDataShape.hasRank()) {
5085 const int64_t inputDataLastDim =
5086 inputDataShape.getDimSize(inputDataShape.getRank() - 1);
5087 if (ShapedType::isStatic(inputDataLastDim) &&
5088 inputDataLastDim % blockSize != 0)
5089 return emitOpError() << "expect last dimension of input_data ("
5090 << inputDataLastDim
5091 << ") to be divisible by block_size (" << blockSize
5092 << ")";
5093 }
5094
5095 const ShapeAdaptor outputDataShape = ShapeAdaptor(outputDataType);
5096 const Type outputScaleType = getResult(1).getType();
5097 const ShapeAdaptor outputScaleShape = ShapeAdaptor(outputScaleType);
5098 if (outputDataShape.hasRank() && outputScaleShape.hasRank()) {
5099 SmallVector<int64_t> outputDataDims, outputScaleDims;
5100 outputDataShape.getDims(outputDataDims);
5101 outputScaleShape.getDims(outputScaleDims);
5102
5103 if (outputDataDims.size() != outputScaleDims.size() ||
5104 failed(verifyCompatibleShape(
5105 ArrayRef<int64_t>(outputDataDims).drop_back(1),
5106 ArrayRef<int64_t>(outputScaleDims).drop_back(1))))
5107 return emitOpError() << "require compatible shapes for output_data ("
5108 << outputDataType << ") and " << "output_scale ("
5109 << outputScaleType
5110 << ") except for the last dimension";
5111
5112 const int64_t outputDataLastDim = outputDataDims.back();
5113 const SmallVector<int64_t, 2> dimsToCheck{outputDataLastDim / blockSize,
5114 outputScaleDims.back()};
5115 if (ShapedType::isStatic(outputDataLastDim) &&
5116 failed(verifyCompatibleDims(dimsToCheck)))
5117 return emitOpError()
5118 << "expect last dimension of output_scale ("
5119 << outputScaleDims.back()
5120 << ") to be equal to last dimension of output_data / block_size ("
5121 << outputDataDims.back() / blockSize << ")";
5122 }
5123
5124 return success();
5125}
5126
5127LogicalResult IfOp::inferReturnTypeComponents(
5128 MLIRContext *context, ::std::optional<Location> location,
5129 IfOp::Adaptor adaptor,
5130 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
5131 llvm::SmallVector<tosa::YieldOp> yieldOps;
5132 for (Region *region : adaptor.getRegions()) {
5133 for (auto &block : *region)
5134 if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
5135 yieldOps.push_back(returnOp);
5136 }
5137
5138 if (yieldOps.empty())
5139 return failure();
5140
5141 // Get the initial type information for the yield op.
5142 llvm::SmallVector<ValueKnowledge> resultKnowledge;
5143 resultKnowledge.reserve(yieldOps.front().getNumOperands());
5144 for (auto operand : yieldOps.front().getOperands()) {
5145 resultKnowledge.push_back(
5146 ValueKnowledge::getKnowledgeFromType(operand.getType()));
5147 }
5148
5149 for (auto yieldOp : yieldOps) {
5150 if (resultKnowledge.size() != yieldOp.getNumOperands())
5151 return failure();
5152
5153 for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
5154 int32_t index = it.index();
5155 auto meet = ValueKnowledge::meet(
5156 resultKnowledge[index],
5157 ValueKnowledge::getKnowledgeFromType(it.value().getType()));
5158 if (!meet)
5159 continue;
5160 resultKnowledge[index] = meet;
5161 }
5162 }
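// Illustrative note: ValueKnowledge::meet combines the facts of both types,
// so meeting tensor<2x?xf32> with tensor<?x3xf32> should refine to
// tensor<2x3xf32>, while contradictory static dimensions make the meet fail,
// in which case the `continue` above keeps the existing knowledge.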
5163
5164 for (const ValueKnowledge &result : resultKnowledge) {
5165 inferredReturnShapes.push_back(result.getShapedTypeComponents());
5166 }
5167
5168 return success();
5169}
5170
5171LogicalResult WhileOp::inferReturnTypeComponents(
5172 MLIRContext *context, ::std::optional<Location> location,
5173 WhileOp::Adaptor adaptor,
5174 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
5175 llvm::SmallVector<tosa::YieldOp> yieldOps;
5176 for (auto &block : adaptor.getBodyGraph())
5177 if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
5178 yieldOps.push_back(returnOp);
5179
5180 // TOSA's while must have a tosa.yield as its terminator. If none is found,
5181 // this tosa.while is invalid.
5182 if (yieldOps.empty())
5183 return failure();
5184
5185 // Get the initial type information from the first yield op's operands.
5186 llvm::SmallVector<ValueKnowledge> resultKnowledge;
5187 resultKnowledge.reserve(yieldOps.front().getNumOperands());
5188 for (auto operand : yieldOps.front().getOperands()) {
5189 resultKnowledge.push_back(
5190 ValueKnowledge::getKnowledgeFromType(operand.getType()));
5191 }
5192
5193 for (auto yieldOp : yieldOps) {
5194 if (resultKnowledge.size() != yieldOp.getNumOperands())
5195 return failure();
5196
5197 for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
5198 int32_t index = it.index();
5199 if (auto meet = ValueKnowledge::meet(
5200 resultKnowledge[index],
5201 ValueKnowledge::getKnowledgeFromType(it.value().getType()))) {
5202 resultKnowledge[index] = meet;
5203 }
5204 }
5205 }
5206
5207 for (const ValueKnowledge &result : resultKnowledge) {
5208 inferredReturnShapes.push_back(result.getShapedTypeComponents());
5209 }
5210
5211 return success();
5212}
5213
5214std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
5215 if (auto vt = llvm::dyn_cast<VectorType>(getType()))
5216 return llvm::to_vector<4>(vt.getShape());
5217 return std::nullopt;
5218}
5219
5221 Block::BlockArgListType blocksArgs,
5222 ValueRange initializers,
5223 StringRef prefix = "") {
5224 assert(blocksArgs.size() == initializers.size() &&
5225 "expected same length of arguments and initializers");
5226 if (initializers.empty())
5227 return;
5228
5229 parser << prefix << '(';
5230 llvm::interleaveComma(
5231 llvm::zip(blocksArgs, initializers), parser,
5232 [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
5233 parser << ")";
5234}
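// Illustrative note: for two initializers this prints a list of the form
// "(%arg0 = %0, %arg1 = %1)", preceded by the given prefix (the callers
// below pass a single space).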
5235
5236 // The parse and print of IfOp follow the implementation in the SCF dialect.
5237ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
5238 // Create the regions for 'then'.
5239 result.regions.reserve(2);
5240 Region *thenRegion = result.addRegion();
5241 Region *elseRegion = result.addRegion();
5242
5243 OpAsmParser::UnresolvedOperand cond;
5244
5245 if (parser.parseOperand(cond))
5246 return failure();
5247
5248 SmallVector<OpAsmParser::Argument, 4> regionArgs;
5249 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
5250
5251 // Parse the optional block arguments
5252 OptionalParseResult listResult =
5253 parser.parseOptionalAssignmentList(regionArgs, operands);
5254 if (listResult.has_value() && failed(listResult.value()))
5255 return failure();
5256
5257 // Parse a colon.
5258 if (failed(parser.parseColon()))
5259 return parser.emitError(parser.getCurrentLocation(),
5260 "expected type for condition operand");
5261
5262 // Parse the type of the condition operand
5263 Type condType;
5264 if (failed(parser.parseType(condType)))
5265 return parser.emitError(parser.getCurrentLocation(),
5266 "expected type for condition operand");
5267
5268 // Resolve operand with provided type
5269 if (failed(parser.resolveOperand(cond, condType, result.operands)))
5270 return failure();
5271
5272 // Parse optional block arg types
5273 if (listResult.has_value()) {
5274 FunctionType functionType;
5275
5276 if (failed(parser.parseType(functionType)))
5277 return parser.emitError(parser.getCurrentLocation())
5278 << "expected list of types for block arguments "
5279 << "followed by arrow type and list of return types";
5280
5281 result.addTypes(functionType.getResults());
5282
5283 if (functionType.getNumInputs() != operands.size()) {
5284 return parser.emitError(parser.getCurrentLocation())
5285 << "expected as many input types as operands " << "(expected "
5286 << operands.size() << " got " << functionType.getNumInputs()
5287 << ")";
5288 }
5289
5290 // Resolve input operands.
5291 if (failed(parser.resolveOperands(operands, functionType.getInputs(),
5292 parser.getCurrentLocation(),
5293 result.operands)))
5294 return failure();
5295 } else {
5296 // Parse optional results type list.
5297 if (parser.parseOptionalArrowTypeList(result.types))
5298 return failure();
5299 }
5300
5301 // Parse the 'then' region.
5302 if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
5303 return failure();
5304
5305 // If we find an 'else' keyword then parse the 'else' region.
5306 if (!parser.parseOptionalKeyword("else")) {
5307 if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
5308 return failure();
5309 }
5310
5311 // Parse the optional attribute list.
5312 if (parser.parseOptionalAttrDict(result.attributes))
5313 return failure();
5314 return success();
5315}
5316
5317void IfOp::print(OpAsmPrinter &p) {
5318 p << " " << getCondition();
5319
5320 printInitializationList(p, getThenGraph().front().getArguments(),
5321 getInputList(), " ");
5322 p << " : ";
5323 p << getCondition().getType();
5324
5325 if (!getInputList().empty()) {
5326 p << " (";
5327 llvm::interleaveComma(getInputList().getTypes(), p);
5328 p << ")";
5329 }
5330 p.printArrowTypeList(getResultTypes());
5331 p << " ";
5332
5333 p.printRegion(getThenGraph());
5334
5335 // Print the 'else' region if it exists and has a block.
5336 auto &elseRegion = getElseGraph();
5337 if (!elseRegion.empty()) {
5338 p << " else ";
5339 p.printRegion(elseRegion);
5340 }
5341
5342 p.printOptionalAttrDict((*this)->getAttrs());
5343}
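// Illustrative note (value names are hypothetical): the parse/print pair
// above round-trips syntax roughly of the form
//   tosa.cond_if %cond (%arg0 = %v0) : tensor<i1> (tensor<f32>) -> tensor<f32> {
//     ...
//   } else {
//     ...
//   }
// where the "(%arg0 = %v0)" initialization list and the else region are
// optional.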
5344
5345LogicalResult IfOp::verify() {
5346 if (errorIfTypeOrShapeMismatch(*this, getThenGraph().front().getArguments(),
5347 "'then_graph' arguments", getInputList(),
5348 "'input_list'")
5349 .failed())
5350 return failure();
5351
5352 if (errorIfTypeOrShapeMismatch(*this, getElseGraph().front().getArguments(),
5353 "'else_graph' arguments", getInputList(),
5354 "'input_list'")
5355 .failed())
5356 return failure();
5357
5358 // If the terminator is absent, MLIR's own verifier will report it for us.
5359 if (getThenGraph().front().mightHaveTerminator()) {
5360 auto thenYield =
5361 dyn_cast<tosa::YieldOp>(getThenGraph().front().getTerminator());
5362 if (thenYield && errorIfTypeOrShapeMismatch(
5363 *this, thenYield.getInputs(), "'then_graph' results",
5364 getOutputList(), "'output_list'")
5365 .failed())
5366 return failure();
5367 }
5368
5369 // If the terminator is absent, MLIR's own verifier will report it for us.
5370 if (getElseGraph().front().mightHaveTerminator()) {
5371 auto elseYield =
5372 dyn_cast<tosa::YieldOp>(getElseGraph().front().getTerminator());
5373 if (elseYield && errorIfTypeOrShapeMismatch(
5374 *this, elseYield.getInputs(), "'else_graph' results",
5375 getOutputList(), "'output_list'")
5376 .failed())
5377 return failure();
5378 }
5379
5380 auto condType = getCondition().getType();
5381 if (errorIfShapeNotSizeOne(*this, condType).failed())
5382 return emitOpError() << "'condition' must be a size 1 tensor, got "
5383 << condType;
5384
5385 return success();
5386}
5387
5388LogicalResult WhileOp::verify() {
5389 if (errorIfTypeOrShapeMismatch(*this, getInputList(), "'input_list'",
5390 getOutputList(), "'output_list'")
5391 .failed())
5392 return failure();
5393
5394 if (errorIfTypeOrShapeMismatch(*this, getCondGraph().front().getArguments(),
5395 "'cond_graph' arguments", getInputList(),
5396 "'input_list'")
5397 .failed())
5398 return failure();
5399
5400 if (errorIfTypeOrShapeMismatch(*this, getBodyGraph().front().getArguments(),
5401 "'body_graph' arguments", getInputList(),
5402 "'input_list'")
5403 .failed())
5404 return failure();
5405
5406 if (getBodyGraph().front().mightHaveTerminator()) {
5407 auto bodyYield =
5408 dyn_cast<tosa::YieldOp>(getBodyGraph().front().getTerminator());
5409 if (bodyYield && errorIfTypeOrShapeMismatch(*this, bodyYield.getInputs(),
5410 "'body_graph' results",
5411 getInputList(), "'input_list'")
5412 .failed())
5413 return failure();
5414 }
5415
5416 // Condition block output must be a single element tensor with a single bool
5417 // value.
5418 if (!getCondGraph().front().mightHaveTerminator())
5419 return success();
5420
5421 auto condYield =
5422 dyn_cast<tosa::YieldOp>(getCondGraph().front().getTerminator());
5423 if (!condYield)
5424 return success();
5425
5426 if (condYield.getInputs().size() != 1)
5427 return emitOpError() << "require 'cond_graph' to have exactly one result";
5428
5429 auto condOutType = condYield.getInputs()[0].getType();
5430 if (errorIfShapeNotSizeOne(*this, condOutType).failed())
5431 return emitOpError() << "'cond_graph' result must be a size 1 tensor, got "
5432 << condOutType;
5433
5434 if (!getElementTypeOrSelf(condOutType).isInteger(1))
5435 return emitOpError() << "'cond_graph' result must be a boolean tensor, got "
5436 << condOutType;
5437
5438 return success();
5439}
5440
5441LogicalResult ReverseOp::verify() {
5442 if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
5443 /* outType = */ getOutput().getType())
5444 .failed())
5445 return failure();
5446 TensorType inputType = getInput1().getType();
5447 TensorType outputType = getOutput().getType();
5448 int32_t reverseAxis = getAxis();
5449
5450 if (reverseAxis < 0)
5451 return emitOpError("expected non-negative reverse axis");
5452 if (inputType.hasRank()) {
5453 int64_t inputRank = inputType.getRank();
5454 // We allow for a special case where the input/output shape has rank 0 and
5455 // axis is also 0.
5456 if (reverseAxis >= inputRank && (reverseAxis != 0 || inputRank != 0))
5457 return emitOpError("expect input tensor rank (")
5458 << inputRank << ") to be larger than reverse axis (" << reverseAxis
5459 << ")";
5460 }
5461 if (outputType.hasRank()) {
5462 int64_t outputRank = outputType.getRank();
5463 if (inputType.hasRank() && outputRank != inputType.getRank())
5464 return emitOpError(
5465 "expect output tensor rank to be equal to input tensor rank");
5466 if (reverseAxis >= outputRank && (reverseAxis != 0 || outputRank != 0))
5467 return emitOpError("expect output tensor rank (")
5468 << outputRank << ") to be larger than reverse axis ("
5469 << reverseAxis << ")";
5470 }
5471 return success();
5472}
5473
5474LogicalResult tosa::SelectOp::verify() {
5475 // verify input2 and input3 have same element type as output
5476 if (verifySameElementTypes(*this, /* inType = */ getOnTrue().getType(),
5477 /* outType = */ getOutput().getType())
5478 .failed() ||
5479 verifySameElementTypes(*this, /* inType = */ getOnFalse().getType(),
5480 /* outType = */ getOutput().getType())
5481 .failed()) {
5482 return failure();
5483 }
5484 // verify input1 has element type of bool
5485 auto predicateType = llvm::dyn_cast<ShapedType>(getPred().getType());
5486 if (!predicateType) {
5487 return emitOpError("expect shaped tensor for input1, got ")
5488 << getInput1().getType();
5489 }
5490 auto predicateElementType = predicateType.getElementType();
5491 if (!predicateElementType.isInteger(1)) {
5492 return emitOpError("expect element type of bool for input1, got ")
5493 << predicateElementType;
5494 }
5495
5496 return success();
5497}
5498
5499LogicalResult tosa::VariableReadOp::verify() {
5500 if (verifyVariableOpErrorIf(*this, getOutput1().getType(), "'output1'")
5501 .failed())
5502 return failure();
5503
5504 return success();
5505}
5506
5507LogicalResult tosa::VariableWriteOp::verify() {
5508 if (verifyVariableOpErrorIf(*this, getInput1().getType(), "'input1'")
5509 .failed())
5510 return failure();
5511
5512 return success();
5513}
5514
5515 // The parse and print of WhileOp follow the implementation in the SCF dialect.
5516ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
5517 SmallVector<OpAsmParser::Argument, 4> regionArgs;
5518 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
5519 Region *cond = result.addRegion();
5520 Region *body = result.addRegion();
5521
5522 OptionalParseResult listResult =
5523 parser.parseOptionalAssignmentList(regionArgs, operands);
5524 if (listResult.has_value() && failed(listResult.value()))
5525 return failure();
5526
5527 FunctionType functionType;
5528 SMLoc typeLoc = parser.getCurrentLocation();
5529 if (failed(parser.parseColonType(functionType)))
5530 return failure();
5531
5532 result.addTypes(functionType.getResults());
5533
5534 if (functionType.getNumInputs() != operands.size()) {
5535 return parser.emitError(typeLoc)
5536 << "expected as many input types as operands " << "(expected "
5537 << operands.size() << " got " << functionType.getNumInputs() << ")";
5538 }
5539
5540 // Resolve input operands.
5541 if (failed(parser.resolveOperands(operands, functionType.getInputs(),
5542 parser.getCurrentLocation(),
5543 result.operands)))
5544 return failure();
5545
5546 // Propagate the types into the region arguments.
5547 for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
5548 regionArgs[i].type = functionType.getInput(i);
5549
5550 return failure(parser.parseRegion(*cond, regionArgs) ||
5551 parser.parseKeyword("do") || parser.parseRegion(*body) ||
5552 parser.parseOptionalAttrDictWithKeyword(result.attributes));
5553}
5554
5555void WhileOp::print(OpAsmPrinter &parser) {
5556 printInitializationList(parser, getCondGraph().front().getArguments(),
5557 getInputList(), " ");
5558 parser << " : ";
5559 parser.printFunctionalType(getInputList().getTypes(),
5560 getResults().getTypes());
5561 parser << ' ';
5562 parser.printRegion(getCondGraph(), /*printEntryBlockArgs=*/false);
5563 parser << " do ";
5564 parser.printRegion(getBodyGraph());
5565 parser.printOptionalAttrDictWithKeyword((*this)->getAttrs());
5566}
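// Illustrative note (value names are hypothetical): the printed form is
// roughly
//   tosa.while_loop (%arg0 = %v0) : (tensor<i32>) -> tensor<i32> {
//     ... // cond_graph, must yield a single boolean
//   } do {
//     ... // body_graph
//   }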
5567
5568// Create a rank-1 const tensor for zero point of the source tensor.
5569std::optional<Value> mlir::tosa::createZeroPointTensor(OpBuilder &builder,
5570 Location loc,
5571 Type srcElemType,
5572 int64_t zp) {
5573 srcElemType = getStorageElementTypeOrSelf(srcElemType);
5574 auto zpType = mlir::RankedTensorType::get({1}, srcElemType);
5575 if (llvm::isa<FloatType>(srcElemType)) {
5576 auto zpAttr = DenseElementsAttr::get(
5577 zpType, builder.getFloatAttr(srcElemType, static_cast<double>(zp)));
5578 return tosa::ConstOp::create(builder, loc, zpType, zpAttr);
5579 }
5580 if (llvm::isa<IntegerType>(srcElemType)) {
5581 auto zpAttr =
5582 DenseElementsAttr::get(zpType, builder.getIntegerAttr(srcElemType, zp));
5583 return tosa::ConstOp::create(builder, loc, zpType, zpAttr);
5584 }
5585 llvm::errs() << "zero point is not allowed for unsupported data types\n";
5586 return std::nullopt;
5587}
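// Usage sketch (builder and loc assumed to be in scope): for an i8 source
// type,
//   std::optional<Value> zp =
//       createZeroPointTensor(builder, loc, builder.getI8Type(), 0);
// materializes a tosa.const of type tensor<1xi8>; an unsupported element
// type yields std::nullopt.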
5588
5589//===----------------------------------------------------------------------===//
5590// TOSA Shape and Shape Operators Helper functions.
5591//===----------------------------------------------------------------------===//
5592
5593 bool mlir::tosa::isa_tosa_shape_type(mlir::Type t) {
5594 return mlir::isa<tosa::shapeType>(t);
5595}
5596
5597LogicalResult
5598mlir::tosa::shapeType::verify(function_ref<InFlightDiagnostic()> emitError,
5599 int rank) {
5600 if (rank < 0)
5601 return emitError() << "invalid rank (must be >= 0): " << rank;
5602 return success();
5603}
5604 LogicalResult
5605 mlir::OpTrait::tosa::verifyTosaResolvableShapeOperands(Operation *op) {
5606 for (auto v : op->getOperands()) {
5607 if (mlir::isa<::mlir::tosa::shapeType>(v.getType())) {
5608 Operation *definingOp = v.getDefiningOp();
5609 if (!definingOp || !definingOp->hasTrait<TosaShapeOperator>()) {
5610 return op->emitOpError("shape operand is not compile time resolvable");
5611 }
5612 }
5613 }
5614 return success();
5615}
5616
5617 LogicalResult
5618 mlir::OpTrait::tosa::verifyTosaShapeOperatorWithSameRanks(Operation *op) {
5619 if (failed(OpTrait::impl::verifyAtLeastNOperands(op, 1)))
5620 return failure();
5621
5622 // delegate function that returns rank of shape type
5623 auto getRank = [](const Type type) {
5624 return mlir::cast<mlir::tosa::shapeType>(type).getRank();
5625 };
5626 auto operandTypes = op->getOperandTypes();
5627 auto resultTypes = op->getResultTypes();
5628
5629 auto rank = getRank(*op->getOperandTypes().begin());
5630 for (auto type : operandTypes) {
5631 if (getRank(type) != rank) {
5632 return op->emitOpError("operands don't have matching ranks");
5633 }
5634 }
5635 for (auto type : resultTypes) {
5636 if (getRank(type) != rank) {
5637 return op->emitOpError("result shape has different rank than operands");
5638 }
5639 }
5640 return success();
5641}
5642
5643//===----------------------------------------------------------------------===//
5644// TOSA Shape Operators verify functions.
5645//===----------------------------------------------------------------------===//
5646
5647LogicalResult tosa::ConstShapeOp::verify() {
5648 // check one dimensional rank
5649 auto valuesRank = getValues().getType().getRank();
5650 if (valuesRank != 1)
5651 return emitOpError("expect elements in attribute values with rank 1");
5652 // check that the number of elements in the values attr equals the rank of the result shape
5653 auto count = getValues().getNumElements();
5654 auto rank = (cast<tosa::shapeType>(getResult().getType())).getRank();
5655 if (count != rank && (count != 1 || rank != 0)) {
5656 return emitOpError("expect number of elements in attribute values (")
5657 << count << ") to be equal to the rank (" << rank
5658 << ") for the result shape type";
5659 }
5660 return success();
5661}
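// Illustrative note: a tosa.const_shape with values = dense<[2, 3, 4]>
// (three elements) must produce a result of type !tosa.shape<3>; the
// (count = 1, rank = 0) pair is the one permitted exception.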
5662
5663LogicalResult tosa::DimOp::verify() {
5664 const tosa::shapeType outShapeType =
5665 cast<tosa::shapeType>(getResult().getType());
5666 if (outShapeType.getRank() != 1)
5667 return emitOpError("expect output shape type to contain one element, got ")
5668 << outShapeType;
5669
5670 const ShapeAdaptor inputType(getInput1().getType());
5671 if (inputType.hasRank()) {
5672 const int64_t inputRank = inputType.getRank();
5673 const int64_t axis = getAxisAttr().getInt();
5674 if (axis < 0 || axis >= inputRank)
5675 return emitOpError("expect axis to be in the range [0, ")
5676 << inputRank << "), got " << axis;
5677 }
5678 return success();
5679}
5680
5681LogicalResult tosa::ConcatShapeOp::verify() {
5682 const tosa::shapeType outShapeType =
5683 cast<tosa::shapeType>(getResult().getType());
5684 const int64_t outputRank = outShapeType.getRank();
5685 const Operation::operand_range inputList = getInput();
5686
5687 if (inputList.size() == 0)
5688 return emitOpError("requires at least one input shape");
5689
5690 if (llvm::any_of(inputList, [](Value v) {
5691 return cast<tosa::shapeType>(v.getType()).getRank() == 0;
5692 }))
5693 return emitOpError("requires all input shapes to have a rank greater than 0");
5694
5695 const int64_t inputsRank =
5696 llvm::accumulate(inputList, 0, [](int64_t acc, const Value &input) {
5697 const tosa::shapeType inShapeType =
5698 cast<tosa::shapeType>(input.getType());
5699 return acc + inShapeType.getRank();
5700 });
5701 if (outputRank != inputsRank)
5702 return emitOpError("requires output shape rank to be equal to the sum of "
5703 "the input shape ranks (")
5704 << inputsRank << "), got " << outputRank;
5705
5706 return success();
5707}
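// Illustrative note: concatenating inputs of type !tosa.shape<2> and
// !tosa.shape<3> must produce a !tosa.shape<5> result (2 + 3); rank-0
// inputs are rejected above.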
5708
5709LogicalResult tosa::SliceShapeOp::verify() {
5710 std::optional<int32_t> start;
5711 DenseIntElementsAttr startAttr;
5712 if (matchPattern(getStart(), m_Constant(&startAttr)))
5713 start = startAttr.getValues<int32_t>()[0];
5714 if (start && start.value() < 0)
5715 return emitOpError("expected non-negative start index, got ")
5716 << start.value();
5717
5718 std::optional<int32_t> size;
5719 DenseIntElementsAttr sizeAttr;
5720 if (matchPattern(getSize(), m_Constant(&sizeAttr)))
5721 size = sizeAttr.getValues<int32_t>()[0];
5722 if (size && size.value() <= 0)
5723 return emitOpError("expected positive size, got ") << size.value();
5724
5725 if (!size)
5726 return success();
5727
5728 const tosa::shapeType outShapeType =
5729 cast<tosa::shapeType>(getResult().getType());
5730 const int64_t outputRank = outShapeType.getRank();
5731 if (outputRank != size)
5732 return emitOpError(
5733 "expected output type size to be equal to size attribute, got ")
5734 << outputRank << " vs " << size.value();
5735
5736 if (!start)
5737 return success();
5738
5739 const tosa::shapeType inShapeType =
5740 cast<tosa::shapeType>(getInput().getType());
5741 const int64_t inputRank = inShapeType.getRank();
5742 const int64_t sliceSize = start.value() + size.value();
5743 if (sliceSize > inputRank)
5744 return emitOpError("expected start + size to be less than or equal to "
5745 "input shape rank (")
5746 << inputRank << "), got " << sliceSize;
5747
5748 return success();
5749}
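// Illustrative note: slicing a !tosa.shape<4> with start = 1 and size = 2
// must produce a !tosa.shape<2> result, and start + size = 3 <= 4 satisfies
// the bound checked above.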
5750
5751//===----------------------------------------------------------------------===//
5752// TOSA Attribute Definitions.
5753//===----------------------------------------------------------------------===//
5754
5755#define GET_ATTRDEF_CLASSES
5756#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
5757
5758//===----------------------------------------------------------------------===//
5759// TOSA Type Definitions.
5760//===----------------------------------------------------------------------===//
5761#define GET_TYPEDEF_CLASSES
5762#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
5763
5764//===----------------------------------------------------------------------===//
5765// TOSA Operator Definitions.
5766//===----------------------------------------------------------------------===//
5767
5768#define GET_OP_CLASSES
5769#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
Definition SCF.cpp:496
true
Given two iterators into the same block, return "true" if a is before `b.
lhs
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
static std::string diag(const llvm::Value &value)
static void printShapeToDiagnostic(InFlightDiagnostic &diag, ArrayRef< int64_t > shape)
Definition TosaOps.cpp:653
static void buildMatMulOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value a, Value b)
The tosa.matmul op is also intended to be generated where a fully_connected op must be constructed wh...
Definition TosaOps.cpp:1428
static LogicalResult verifySameElementTypes(T op, Type aType, Type bType, StringRef aName="input", StringRef bName="output")
Definition TosaOps.cpp:1054
LogicalResult inferConvReturnTypeComponents(AdaptorT adaptor, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
Definition TosaOps.cpp:4292
static SmallVector< int64_t > convertToMlirShape(ArrayRef< int64_t > shape)
Definition TosaOps.cpp:139
static LogicalResult ReduceInferReturnTypes(ShapeAdaptor operandShape, Type inputType, IntegerAttr axis, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
Definition TosaOps.cpp:3840
static void buildAvgPool2dAdaptiveOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, DenseI64ArrayAttr kernel, DenseI64ArrayAttr stride, DenseI64ArrayAttr pad, TypeAttr accType)
This builder mirrors avg_pool2d quant-info handling and materializes kernel/stride/pad as const_shape...
Definition TosaOps.cpp:1501
static LogicalResult verifyRescaleValueAndZpTypes(Operation *op, Value val, Value valZp, StringRef name)
Definition TosaOps.cpp:595
static LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type)
Definition TosaOps.cpp:1018
#define REDUCE_SHAPE_INFER(OP)
Definition TosaOps.cpp:3865
static LogicalResult verifyConvOp(T op)
Definition TosaOps.cpp:719
static LogicalResult verifyAvgPoolCommonTypeAndZpChecks(T op)
Definition TosaOps.cpp:1217
static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name)
Definition TosaOps.cpp:1027
static LogicalResult poolingInferReturnTypes(ShapeAdaptor inputShape, ArrayRef< int64_t > kernel, ArrayRef< int64_t > stride, ArrayRef< int64_t > pad, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
Definition TosaOps.cpp:4055
static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value paddings)
This builder is called on TOSA pad operator that needs to create its own OptionalAttr quantization_at...
Definition TosaOps.cpp:1588
static LogicalResult verifyPoolingOpImpl(Operation *op, ArrayRef< int64_t > kernel, ArrayRef< int64_t > strides, ArrayRef< int64_t > padding, Value input, Value output)
Definition TosaOps.cpp:1119
static std::optional< int64_t > idivCheck(const int64_t lhs, const int64_t rhs)
Definition TosaOps.cpp:578
static void buildVariableOp(OpBuilder &builder, OperationState &result, StringRef name, Type variableType, Attribute initialValue)
Definition TosaOps.cpp:1602
LogicalResult verifyConvOutputSize(Operation *op, const int64_t inputSize, const int64_t kernelSize, const int64_t outputSize, const int64_t padBefore, const int64_t padAfter, const int64_t stride, const int64_t dilation, const llvm::StringRef dimName, const llvm::StringRef dimAxis, const llvm::StringRef padBeforeName, const llvm::StringRef padAfterName)
Definition TosaOps.cpp:682
static LogicalResult verifyReduceOp(T op)
Definition TosaOps.cpp:3890
#define NARY_SHAPE_INFER(OP)
Definition TosaOps.cpp:3958
#define ZERO_POINT_HELPER(OP, OPERAND_NAME, SIGN_EXTEND)
Definition TosaOps.cpp:3162
static void buildTransConvOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value weight, Value bias, DenseI64ArrayAttr outpad, DenseI64ArrayAttr stride, TypeAttr accType)
Handles tosa.transpose_conv2d which has outpad and output shape attributes.
Definition TosaOps.cpp:1406
static void extractAdaptivePoolingConstShapeOperands(T op, AdaptivePoolingConstShapeValues &values)
Definition TosaOps.cpp:1275
static LogicalResult verifyConvOpErrorIf(T op)
Definition TosaOps.cpp:873
static FailureOr< int64_t > getZeroPoint(Value val, bool signExtend)
Definition TosaOps.cpp:3081
static constexpr bool IsSupportedAdaptivePoolConstShapeVerifyOp
Definition TosaOps.cpp:1268
LogicalResult tryUpdateDimOrFailure(Operation *op, int64_t &currDim, const int64_t newDim, const StringRef operandName, const StringRef dimName)
Definition TosaOps.cpp:638
static LogicalResult verifyConvOpModes(T op)
Definition TosaOps.cpp:824
static LogicalResult NAryInferReturnTypes(const ValueShapeRange &operands, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
Definition TosaOps.cpp:3946
static Type getStorageElementTypeOrSelf(Type type)
Definition TosaOps.cpp:584
#define COMPATIBLE_RETURN_TYPES(OP)
Definition TosaOps.cpp:3856
static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands, SmallVector< int64_t > &outShape)
Definition TosaOps.cpp:1646
static LogicalResult verifyOutputShapeCompatibleWithExpected(Operation *op, ShapedType outputType, ArrayRef< int64_t > expectedShape, StringRef outputName="output")
Definition TosaOps.cpp:666
static void buildNegateOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input)
This builder is called on single-parameter negate operator to construct input and output zero points ...
Definition TosaOps.cpp:1548
static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value weight, Value bias, DenseI64ArrayAttr pad, DenseI64ArrayAttr stride, DenseI64ArrayAttr dilation, TypeAttr accType)
This builder is called on all convolution operators except TransposeConv, which has specialized outpu...
Definition TosaOps.cpp:1382
static void buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, DenseArrayAttr kernel, DenseArrayAttr stride, DenseArrayAttr pad, TypeAttr accType)
Both the tosa.avg_pool2d and unary ops use the same UnaryOpQuantizationAttr but avg_pool operator has...
Definition TosaOps.cpp:1457
static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, Type type1, StringRef name1, Type type2, StringRef name2)
Definition TosaOps.cpp:976
static FailureOr< int64_t > getConstantScalarIntValue(Value val)
Definition TosaOps.cpp:3107
static FailureOr< int64_t > resolveBroadcastDim(const int64_t dim1, const int64_t dim2)
Definition TosaOps.cpp:1632
static LogicalResult verifyZeroPoint(T op, Value val, const int64_t &zp, const std::string &operand)
Definition TosaOps.cpp:3120
static LogicalResult verifyPoolingOp(T op)
Definition TosaOps.cpp:1211
static LogicalResult verifyDimIsPowerOfTwo(Operation *op, const int64_t dimSize, const llvm::StringRef dimName)
Definition TosaOps.cpp:1732
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
static void updateIfDynamic(int64_t &current, int64_t candidate)
Definition TosaOps.cpp:4094
void inferWeightShape(SmallVectorImpl< int64_t > &outputShape, SmallVectorImpl< int64_t > &weightSpatial)
Definition TosaOps.cpp:4187
LogicalResult getSpatialParameters(SmallVector< int64_t > &padValues, SmallVector< int64_t > &strideValues, SmallVector< int64_t > &dilationValues)
Definition TosaOps.cpp:4216
void inferInputShape(SmallVectorImpl< int64_t > &outputShape, SmallVectorImpl< int64_t > &inputSpatial)
Definition TosaOps.cpp:4161
ConvInferShapeAdaptor(Conv2DBlockScaledOp::Adaptor adaptor)
Definition TosaOps.cpp:4158
void inferInputShape(SmallVectorImpl< int64_t > &outputShape, SmallVectorImpl< int64_t > &inputSpatial)
Definition TosaOps.cpp:4107
void inferWeightShape(SmallVectorImpl< int64_t > &outputShape, SmallVectorImpl< int64_t > &weightSpatial)
Definition TosaOps.cpp:4122
ConvInferShapeAdaptor(Conv2DOp::Adaptor adaptor)
Definition TosaOps.cpp:4104
LogicalResult getSpatialParameters(SmallVector< int64_t > &padValues, SmallVector< int64_t > &strideValues, SmallVector< int64_t > &dilationValues)
Definition TosaOps.cpp:4140
void inferWeightShape(SmallVectorImpl< int64_t > &outputShape, SmallVectorImpl< int64_t > &weightSpatial)
Definition TosaOps.cpp:4257
ConvInferShapeAdaptor(Conv3DOp::Adaptor adaptor)
Definition TosaOps.cpp:4237
void inferInputShape(SmallVectorImpl< int64_t > &outputShape, SmallVectorImpl< int64_t > &inputSpatial)
Definition TosaOps.cpp:4240
LogicalResult getSpatialParameters(SmallVector< int64_t > &padValues, SmallVector< int64_t > &strideValues, SmallVector< int64_t > &dilationValues)
Definition TosaOps.cpp:4277
virtual ParseResult parseOptionalRBrace()=0
Parse a } token if present.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
virtual void printAttribute(Attribute attr)
void printArrowTypeList(TypeRange &&types)
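The AsmParser hooks above compose into custom syntax parsers by chaining their ParseResults. A minimal sketch, assuming a hypothetical `{ key = <attr> } : <type>` form (parseKeyedAttr is illustrative, not an MLIR API):

#include "mlir/IR/OpImplementation.h"
using namespace mlir;
// Parses `{ key = <attr> } : <type>`; each hook returns failure on a
// mismatch, so the chained || stops at the first error.
static ParseResult parseKeyedAttr(AsmParser &parser, Attribute &value,
                                  Type &type) {
  if (parser.parseLBrace() || parser.parseKeyword("key") ||
      parser.parseEqual() || parser.parseAttribute(value) ||
      parser.parseRBrace() || parser.parseColonType(type))
    return failure();
  return success();
}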
Attributes are known-constant values of operations.
Definition Attributes.h:25
MutableArrayRef< BlockArgument > BlockArgListType
Definition Block.h:95
This class is a general helper class for creating context-global objects like types, attributes, and affine expressions.
Definition Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:232
FloatAttr getFloatAttr(Type type, double value)
Definition Builders.cpp:258
IntegerType getI32Type()
Definition Builders.cpp:67
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:266
DenseIntElementsAttr getIndexTensorAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:197
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
An attribute that represents a reference to a dense integer vector or tensor object.
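The Builder getters and dense-attribute constructors above are typically used together when assembling op attributes. A minimal sketch, assuming arbitrary values; buildExampleAttrs is illustrative only:

// Hypothetical helper exercising common Builder attribute getters.
static void buildExampleAttrs(Builder &b) {
  IntegerAttr idx = b.getIndexAttr(4);
  IntegerAttr i32Val = b.getIntegerAttr(b.getI32Type(), 42);
  FloatAttr f32Val = b.getFloatAttr(b.getF32Type(), 1.5);
  StringAttr name = b.getStringAttr("example");
  // Dense index tensor, e.g. for shape-like payloads.
  DenseIntElementsAttr dims = b.getIndexTensorAttr({2, 3, 4});
  (void)idx; (void)i32Val; (void)f32Val; (void)name; (void)dims;
}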
virtual InFlightDiagnostic emitError(const Twine &msg={}) const =0
Emit an error to the reader.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without specifying the location.
Definition Builders.h:632
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
NamedAttrList is an array of NamedAttributes that tracks whether it is sorted and does some basic work to remain sorted.
void push_back(NamedAttribute newAttribute)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Definition Attributes.h:179
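NamedAttrList is the usual staging container for such attributes before op construction. A minimal sketch, assuming a hypothetical attribute name "axis" (collectAttrs is illustrative):

// Hypothetical helper: stage an attribute, then walk the list.
static void collectAttrs(Builder &b, NamedAttrList &attrs) {
  attrs.push_back(b.getNamedAttr("axis", b.getIndexAttr(0)));
  for (NamedAttribute attr : attrs)
    llvm::outs() << attr.getName().getValue() << " = " << attr.getValue()
                 << "\n";
}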
The OpAsmParser has methods for interacting with the asm parser: parsing things from it, emitting errors, etc.
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to the list on success.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom print for an operation.
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attributes'.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
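The OpAsmParser and OpAsmPrinter hooks above pair up in matching parse/print methods. A minimal sketch for a hypothetical unary op with syntax `%operand attr-dict : type` (parseUnaryLike/printUnaryLike are illustrative):

static ParseResult parseUnaryLike(OpAsmParser &parser,
                                  OperationState &result) {
  OpAsmParser::UnresolvedOperand operand;
  Type type;
  if (parser.parseOperand(operand) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(type) ||
      parser.resolveOperand(operand, type, result.operands))
    return failure();
  result.addTypes(type);
  return success();
}

static void printUnaryLike(OpAsmPrinter &p, Operation *op) {
  p << " " << op->getOperand(0);
  p.printOptionalAttrDict(op->getAttrs());
  p << " : " << op->getOperand(0).getType();
}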
This class helps build Operations.
Definition Builders.h:209
This class indicates that op operates on tosa shape types.
Definition TosaOps.h:129
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
ResultRange result_range
Support result iteration.
Definition Operation.h:436
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g. hasTrait<OperandsAreSignlessIntegerLike>().
Definition Operation.h:775
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:538
OperandRange operand_range
Definition Operation.h:397
operand_type_range getOperandTypes()
Definition Operation.h:423
result_type_range getResultTypes()
Definition Operation.h:454
operand_range getOperands()
Returns an iterator over the underlying Values.
Definition Operation.h:404
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
Type-safe wrapper around a void* for passing properties, including the properties structs of operations.
This class provides an abstraction over the different types of ranges over Regions.
Definition Region.h:357
bool empty()
Definition Region.h:60
Adaptor class to abstract the differences between whether the value is from a ShapedType or ShapedTypeComponents.
bool isDynamicDim(int index) const
Returns whether the index'th dimension is dynamic.
int64_t getDimSize(int index) const
Returns the size of the index'th dimension.
int64_t getRank() const
Returns the rank of the shape.
bool hasStaticShape() const
Returns whether the shape is fully static.
int64_t getNumElements() const
Returns the number of elements in the shape.
void getDims(SmallVectorImpl< int64_t > &res) const
Populates the dimensions from shape referenced.
bool hasRank() const
Returns whether the shape has a rank.
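The ShapeAdaptor queries above are what shape-inference implementations use instead of touching the underlying type directly. A minimal sketch (dumpShape is illustrative):

// Hypothetical helper: print each dimension, '?' when dynamic.
static void dumpShape(ShapeAdaptor shape) {
  if (!shape.hasRank())
    return; // Unranked: per-dimension queries are not meaningful.
  for (int64_t i = 0, e = shape.getRank(); i < e; ++i) {
    if (shape.isDynamicDim(i))
      llvm::errs() << "dim " << i << ": ?\n";
    else
      llvm::errs() << "dim " << i << ": " << shape.getDimSize(i) << "\n";
  }
}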
ShapedTypeComponents that represents the components of a ShapedType.
This class allows for representing and managing the symbol table used by operations with the 'SymbolTable' trait.
Definition SymbolTable.h:24
Operation * lookup(StringRef name) const
Look up a symbol with the specified name, returning null if no such name exists.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and UnrankedTensorType.
ArrayRef< int64_t > getShape() const
Returns the shape of this tensor type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition Types.h:74
bool isF32() const
Definition Types.cpp:40
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition Types.cpp:90
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
bool isF16() const
Definition Types.cpp:38
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
bool isBF16() const
Definition Types.cpp:37
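The Type predicates above commonly gate bitwidth queries, since getIntOrFloatBitWidth asserts on other types. A minimal sketch, assuming the hypothetical helper elementBitWidth:

// Returns the element bitwidth for int/float element types, 0 otherwise.
static unsigned elementBitWidth(Type type) {
  Type elemTy = getElementTypeOrSelf(type);
  if (elemTy.isF32() || elemTy.isF16() || elemTy.isBF16() ||
      elemTy.isInteger())
    return elemTy.getIntOrFloatBitWidth();
  return 0; // e.g. quantized types need their own handling.
}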
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
type_range getTypes() const
Range of values and shapes (corresponding effectively to Shapes dialect's ValueShape type concept).
ShapeAdaptor getShape(int index) const
Returns the shape of index'th operand.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
LogicalResult verifyAtLeastNOperands(Operation *op, unsigned numOperands)
LogicalResult verifyTosaShapeOperatorWithSameRanks(Operation *op)
Definition TosaOps.cpp:5618
LogicalResult verifyTosaResolvableShapeOperands(Operation *op)
Definition TosaOps.cpp:5605
bool getBroadcastedShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2, SmallVectorImpl< int64_t > &resultShape)
Returns true and sets resultShape to the broadcasted shape from the two given shapes if they are broadcastable.
Definition Traits.cpp:59
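getBroadcastedShape is the workhorse behind elementwise result-shape logic. A minimal sketch of a wrapper (broadcastShapes is illustrative):

// Hypothetical wrapper: returns the broadcast of two shapes or failure.
static FailureOr<SmallVector<int64_t>>
broadcastShapes(ArrayRef<int64_t> lhs, ArrayRef<int64_t> rhs) {
  SmallVector<int64_t> result;
  if (!OpTrait::util::getBroadcastedShape(lhs, rhs, result))
    return failure(); // Incompatible, e.g. {2} vs {3}.
  return result;
}

For example, broadcasting {1, 3} with {2, 1} yields {2, 3}.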
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition Utils.cpp:18
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
ConvOpQuantizationAttr buildConvOpQuantizationAttr(OpBuilder &builder, Value input, Value weight)
Method to build ConvOpQuantizationAttr, called from ConvOpQuantInfoBuilder/TransConvOpQuantInfoBuilder.
RankedTensorType getVariableType(VariableOp variableOp)
Type buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType, Value input, Value weight)
Construct ConvOp output type with the correct bitwidth based on input/weight width.
ParseResult parseVariableOpTypeOrInitialValue(OpAsmParser &parser, DenseElementsAttr &varShapeAttr, TypeAttr &typeAttr, Attribute &initialValueAttr)
Definition TosaOps.cpp:228
PadOpQuantizationAttr buildPadOpQuantizationAttr(OpBuilder &builder, Value input)
Builds PadOpQuantizationAttr, called from PadOpQuantInfoBuilder: inputZp: input zeropoint.
constexpr int64_t kInferableDimSize
Represents a dimension in the shape of a tensor that can be inferred based on the other provided dimensions.
Definition TosaOps.h:153
std::pair< Value, Value > createZPsAsConst(OpBuilder &builder, Value input, Value weight)
void printVariableOpTypeOrInitialValue(OpAsmPrinter &p, Operation *op, DenseElementsAttr varShapeAttr, TypeAttr typeAttr, Attribute initialValueAttr)
Definition TosaOps.cpp:253
Value getTosaConstShape(ImplicitLocOpBuilder &builder, llvm::ArrayRef< int64_t > shape)
MatMulOpQuantizationAttr buildMatMulOpQuantizationAttr(OpBuilder &builder, Value a, Value b)
Builds MatMulOpQuantizationAttr, called from MatMulOpQuantInfoBuilder: aZp: input a zeropoint, bZp: input b zeropoint.
unsigned getBitWidth(Type type)
Definition TosaOps.cpp:630
std::optional< Value > createZeroPointTensor(OpBuilder &builder, Location loc, Type srcElemType, int64_t zp=0)
Definition TosaOps.cpp:5569
bool isa_tosa_shape_type(mlir::Type t)
Definition TosaOps.cpp:5593
SmallVector< int64_t > convertFromMlirShape(ArrayRef< int64_t > shape)
UnaryOpQuantizationAttr buildUnaryOpQuantizationAttr(OpBuilder &builder, Value input, Type outputRawType)
Builds UnaryOpQuantizationAttr, called from UnaryOpQuantInfoBuilder: inputZp: input zeropoint, outputZp: output zeropoint.
Type getStorageElementTypeFromQuantized(quant::QuantizedType quantizedType)
Value createPadConstTensor(OpBuilder &builder, Location loc, Value src, int32_t val=0)
Definition TosaOps.cpp:615
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
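Several of the tosa:: helpers above materialize constants consumed by later lowering. A minimal sketch, assuming arbitrary values (materializeTosaConstants is illustrative):

static void materializeTosaConstants(OpBuilder &builder, Location loc,
                                     Type srcElemType) {
  // Rank-1 zero-point tensor for srcElemType; std::nullopt if unsupported.
  std::optional<Value> zp =
      tosa::createZeroPointTensor(builder, loc, srcElemType, /*zp=*/0);
  // A TOSA shape constant holding {1, 2, 3}.
  ImplicitLocOpBuilder ib(loc, builder);
  Value shape = tosa::getTosaConstShape(ib, {1, 2, 3});
  (void)zp; (void)shape;
}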
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair of corresponding entries has a compatible shape.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
LogicalResult verifyCompatibleDims(ArrayRef< int64_t > dims)
Dimensions are compatible if all non-dynamic dims are equal.
LogicalResult verifyRanksMatch(Operation *op, ShapedType lhs, ShapedType rhs, StringRef lhsName, StringRef rhsName)
Verify that two shaped types have matching ranks.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
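matchPattern with m_Constant is the standard way folders test for constant operands. A minimal sketch (both helpers are illustrative):

// Bare check: is this value produced by a constant-foldable op?
static bool isConstant(Value v) { return matchPattern(v, m_Constant()); }

// Binding form: also capture the constant's elements attribute.
static bool getConstantElements(Value v, DenseElementsAttr &attr) {
  return matchPattern(v, m_Constant(&attr));
}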
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
static ValueKnowledge meet(const ValueKnowledge &lhs, const ValueKnowledge &rhs)
Definition ShapeUtils.h:136
static ValueKnowledge getKnowledgeFromType(Type type)
Definition ShapeUtils.h:45
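ValueKnowledge::meet combines what two sources agree on, which is how shape propagation refines types. A minimal sketch (meetTypes is illustrative):

// Hypothetical helper: merge the static type knowledge of two values.
static ValueKnowledge meetTypes(Value a, Value b) {
  auto lhs = ValueKnowledge::getKnowledgeFromType(a.getType());
  auto rhs = ValueKnowledge::getKnowledgeFromType(b.getType());
  return ValueKnowledge::meet(lhs, rhs);
}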